Training a classifier using PyTorch Lightning
This notebook trains a classifier on CIFAR-10 using PyTorch Lightning.

We integrate 3LC with a LightningModule by creating Tables up front (outside the module) and calling 3LC’s public API (tlc.init, tlc.collect_metrics, tlc.metrics.Predictor) directly from standard Lightning hooks (on_train_start, on_train_end).
Project setup¶
[ ]:
# --- 3LC project / Run identity ---------------------------------------------
PROJECT_NAME = "3LC Tutorials - PyTorch Lightning Classification"
RUN_NAME = "Train classifier"
RUN_DESCRIPTION = "Train a resnet model on CIFAR-10"

# --- Data: 3LC dataset names and torchvision download location --------------
TRAIN_DATASET_NAME = "cifar-10-train"
VAL_DATASET_NAME = "cifar-10-val"
DOWNLOAD_PATH = "../../transient_data"

# --- Training ------------------------------------------------------------------
EPOCHS, BATCH_SIZE = 5, 32
NUM_WORKERS = 0  # DataLoader worker processes (0 = load in the main process)

# --- Notebook ------------------------------------------------------------------
INSTALL_DEPENDENCIES = True  # set to False to skip the %pip install cell
Install dependencies¶
[ ]:
# Install the notebook's dependencies: 3LC (with its `pacmap` extra), PyTorch Lightning,
# and the 3lc-examples repository (presumably the source of the `tlc_tools` helpers
# imported below — verify). `%pip` is an IPython magic, so this cell only runs inside
# a notebook/IPython kernel, not as a plain Python script.
if INSTALL_DEPENDENCIES:
    %pip install -q 3lc[pacmap]
    %pip install -q pytorch-lightning
    %pip install -q git+https://github.com/3lc-ai/3lc-examples.git
Imports¶
[ ]:
import pytorch_lightning as pl
import tlc
import torch
import torch.nn.functional as F
import torchvision
from tlc.integration.torch.samplers import create_sampler
from torch.utils.data import DataLoader
from tlc_tools.common import infer_torch_device
Define model creation function¶
[ ]:
# Create model for cifar10 training
def create_model(num_classes=10):
    """Return a randomly initialized ResNet-18 with a ``num_classes``-way output layer.

    Args:
        num_classes: Size of the final classification layer (10 for CIFAR-10).

    Returns:
        A ``torchvision.models.ResNet`` instance with no pretrained weights loaded.
    """
    # ``weights=None`` is the supported way to request random initialization; the
    # ``pretrained=False`` flag used previously is deprecated since torchvision 0.13.
    return torchvision.models.resnet18(weights=None, num_classes=num_classes)
Define the schema of our dataset¶
[ ]:
################## 3LC ##################
# Column layout of the 3LC Tables written below:
#   image  - the raw CIFAR-10 image
#   label  - integer class index, displayed using the names in `classes`
#   weight - per-sample weight read by the 3LC sampler. TableWriter does not add this
#            column on its own (Table.from_torch_dataset used to), so it is declared here.
classes = "airplane automobile bird cat deer dog frog horse ship truck".split()
schema = dict(
    image=tlc.schemas.ImageSchema(),
    label=tlc.schemas.CategoricalLabelSchema(classes=classes),
    weight=tlc.schemas.SampleWeightSchema(),
)
#########################################
Describe the metrics we want to collect¶
[ ]:
################## 3LC ##################
# Define a function for the metrics we want to collect
def metrics_fn(batch, predictor_output: "tlc.metrics.PredictorOutput"):
    """Compute per-sample classification metrics for one batch.

    Args:
        batch: One batch from the metrics dataloader, i.e. ``(images, labels)``. Only
            ``batch[1]`` (integer class labels) is used.
        predictor_output: Output of ``tlc.metrics.Predictor``. ``.forward`` holds the model
            logits, presumably shaped ``(batch, num_classes)``; the class count is read
            from dimension 1.

    Returns:
        Dict of column name -> 1-D NumPy array with one value per sample. These become
        the columns of the Run in the 3LC Dashboard.
    """
    predictions = predictor_output.forward
    # Follow the logits' device rather than a globally inferred one: the Lightning model
    # lives wherever the Trainer put it, which need not match infer_torch_device().
    labels = batch[1].to(predictions.device)
    num_classes = predictions.shape[1]
    one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()

    # Confidence & predicted class
    softmax_output = F.softmax(predictions, dim=1)
    predicted_indices = torch.argmax(predictions, dim=1)
    confidence = torch.gather(softmax_output, 1, predicted_indices.unsqueeze(1)).squeeze(1)
    # Per-sample accuracy (1 if correct, 0 otherwise)
    accuracy = (predicted_indices == labels).float()
    # Unreduced cross-entropy loss
    cross_entropy_loss = F.cross_entropy(predictions, labels, reduction="none")
    # RMSE between softmax distribution and one-hot target, averaged over classes
    mse = F.mse_loss(softmax_output, one_hot_labels, reduction="none").mean(dim=1)
    rmse = torch.sqrt(mse)
    # MAE between softmax distribution and one-hot target, averaged over classes
    mae = F.l1_loss(softmax_output, one_hot_labels, reduction="none").mean(dim=1)

    def to_numpy(tensor):
        # detach(): .numpy() raises on tensors that still require grad.
        return tensor.detach().cpu().numpy()

    # These values will be the columns of the Run in the 3LC Dashboard
    return {
        "loss": to_numpy(cross_entropy_loss),
        "predicted": to_numpy(predicted_indices),
        "accuracy": to_numpy(accuracy),
        "confidence": to_numpy(confidence),
        "rmse": to_numpy(rmse),
        "mae": to_numpy(mae),
    }
# Columns without an entry here get an automatically inferred schema. Listing a column
# lets us customize it — a description for "loss", display name and class names for the
# integer "predicted" column.
schemas = {}
schemas["loss"] = tlc.schemas.Float32Schema(description="Cross entropy loss")
schemas["predicted"] = tlc.schemas.CategoricalLabelSchema(display_name="predicted label", classes=classes)

# Bundle the metrics function with its column schemas into a collector for tlc.collect_metrics.
classification_metrics_collector = tlc.metrics.FunctionalMetricsCollector(
    collection_fn=metrics_fn,
    schema=schemas,
)
#########################################
Create 3LC Tables¶
We create the Tables eagerly, outside the LightningModule. This sidesteps the DDP table-coordination problem that arises when tables are created inside train_dataloader (which Lightning replicates per process). With tables on disk before Trainer.fit(), every process simply opens the same Table.
We use tlc.TableWriter directly — iterating the torchvision dataset and pushing rows in batches. Transforms are attached at sample-time via Table.with_transform, so the Table itself preserves the original PIL images for visualization. We use the validation transform (no augmentation) when collecting metrics on the training table.
[ ]:
# Training pipeline: standard CIFAR-10 augmentation (random 32x32 crop of the image padded
# by 4 pixels, random horizontal flip), then conversion to a tensor and normalization.
# Normalize((0.5,)*3, (0.5,)*3) maps each channel from [0, 1] to [-1, 1].
train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
# Evaluation pipeline: deterministic, no augmentation; same normalization as training.
val_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def train_fn(sample):
    """Map a Table row to an ``(augmented image tensor, label)`` pair for training."""
    return train_transform(sample["image"]), sample["label"]


def val_fn(sample):
    """Map a Table row to an ``(image tensor, label)`` pair without augmentation."""
    return val_transform(sample["image"]), sample["label"]
def write_cifar_table(dataset, dataset_name, chunk_size=1000):
    """Stream a torchvision CIFAR-10 split into a 3LC Table via TableWriter.

    Rows are buffered and handed to the writer in chunks, so the whole split is never
    assembled into a single batch. Every row is written with weight 1.0.

    Args:
        dataset: Iterable of ``(image, label)`` pairs, e.g. a torchvision ``CIFAR10``
            created without a transform (so the Table keeps the original images).
        dataset_name: 3LC dataset name for the Table; an existing one is overwritten.
        chunk_size: Number of rows buffered before each ``writer.add_batch`` call.

    Returns:
        The result of ``TableWriter.finalize()`` (the written Table).

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    writer = tlc.TableWriter(
        project_name=PROJECT_NAME,
        dataset_name=dataset_name,
        schema=schema,
        if_exists="overwrite",
    )
    images, labels = [], []

    def flush():
        """Write the buffered rows (if any) with weight 1.0, then clear the buffers."""
        if not images:
            return
        writer.add_batch(
            {
                "image": images,
                "label": labels,
                "weight": [1.0] * len(images),
            }
        )
        images.clear()
        labels.clear()

    for image, label in dataset:
        images.append(image)
        labels.append(label)
        if len(images) >= chunk_size:
            flush()
    flush()  # write the final, possibly partial, chunk
    return writer.finalize()
# Download (if needed) both CIFAR-10 splits without any transform, so the Tables store
# the original images; transforms are applied later via Table.with_transform.
raw_train_dataset, raw_val_dataset = (
    torchvision.datasets.CIFAR10(root=DOWNLOAD_PATH, train=is_train, download=True)
    for is_train in (True, False)
)
# Write each split to its own 3LC Table (train first, then val).
train_table, val_table = (
    write_cifar_table(raw_train_dataset, TRAIN_DATASET_NAME),
    write_cifar_table(raw_val_dataset, VAL_DATASET_NAME),
)
Define our LightningModule¶
The 3LC integration is just a few standard Lightning hooks:
- `on_train_start` initializes a Run and records hyperparameters.
- `on_train_end` collects per-sample metrics on the train and val tables, then marks the Run completed.
The dataloaders are built directly from the Tables, with a 3LC weighted sampler on the training side via tlc.integration.torch.samplers.create_sampler.
[ ]:
class MyModule(pl.LightningModule):
    """LightningModule that trains a ResNet-18 on 3LC Tables and reports to a 3LC Run.

    The Tables are created outside the module and passed in. 3LC is integrated through
    two standard hooks:

    * ``on_train_start`` creates the Run and records hyperparameters.
    * ``on_train_end`` collects per-sample metrics on both Tables and marks the Run completed.

    Both hooks only talk to 3LC on the global-zero process. A multi-process (e.g. DDP)
    launch therefore produces one Run instead of one per rank.

    Args:
        train_table: 3LC Table holding the training split.
        val_table: 3LC Table holding the validation split.
        batch_size: Batch size for both dataloaders.
        lr: Adam learning rate.
    """

    def __init__(self, train_table, val_table, batch_size=BATCH_SIZE, lr=1e-3):
        super().__init__()
        # The Tables are passed explicitly; keep them out of the recorded hyperparameters.
        self.save_hyperparameters(ignore=["train_table", "val_table"])
        self.train_table = train_table
        self.val_table = val_table
        self.model = create_model()
        self.batch_size = batch_size
        self.lr = lr
        # Set in on_train_start on the global-zero process; stays None elsewhere.
        self.tlc_run: tlc.Run | None = None

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch.

        Without this hook Lightning skips the validation loop, leaving ``val_dataloader`` unused.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log_dict({"val_loss": loss, "val_accuracy": accuracy}, prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        # 3LC sampler driven by the Table's weight column; exclude_zero_weights=True is
        # intended to skip rows whose weight has been set to 0.
        return DataLoader(
            self.train_table.with_transform(train_fn),
            sampler=create_sampler(self.train_table, weighted=True, exclude_zero_weights=True),
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_table.with_transform(val_fn),
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
        )

    def on_train_start(self):
        super().on_train_start()
        # Only one process creates the Run, so multi-process launches don't create duplicates.
        if not self.trainer.is_global_zero:
            return
        self.tlc_run = tlc.init(
            project_name=PROJECT_NAME,
            run_name=RUN_NAME,
            description=RUN_DESCRIPTION,
            parameters=dict(self.hparams_initial),
            if_exists="rename",
        )
        self.tlc_run.set_status_running()

    def on_train_end(self):
        super().on_train_end()
        if not self.trainer.is_global_zero:
            return
        # Metrics are collected once, at the end of training.
        self._collect_3lc_metrics()
        if self.tlc_run is not None:
            self.tlc_run.set_status_completed()

    def _collect_3lc_metrics(self):
        """Collect per-sample metrics on the train and val Tables into the active Run.

        The val transform is used for both splits so metrics are never computed on
        augmented images. The model is run in eval mode without gradients, so ResNet's
        BatchNorm layers use, and do not update, their running statistics. The previous
        train/eval mode is restored afterwards.
        """
        predictor = tlc.metrics.Predictor(self)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for split, table in [("train", self.train_table), ("val", self.val_table)]:
                    tlc.collect_metrics(
                        table=table.with_transform(val_fn),
                        metrics_collectors=[classification_metrics_collector],
                        predictor=predictor,
                        split=split,
                        constants={"epoch": self.current_epoch},
                        exclude_zero_weights=True,
                    )
        finally:
            self.train(was_training)
Run training¶
[ ]:
# Wrap the Tables written above in the LightningModule, then train for EPOCHS epochs.
# The module's own hooks create and finalize the 3LC Run.
module = MyModule(
    train_table=train_table,
    val_table=val_table,
)
trainer = pl.Trainer(max_epochs=EPOCHS)
trainer.fit(model=module)
After training has completed, the Run can be viewed in the 3LC Dashboard.