Training a classifier using PyTorch Lightning

This notebook trains a classifier on CIFAR-10 using PyTorch Lightning.

image1

We integrate 3LC with a LightningModule by creating Tables up front (outside the module) and calling 3LC’s public API (tlc.init, tlc.collect_metrics, tlc.metrics.Predictor) directly from standard Lightning hooks (on_train_start, on_train_end).

Project setup

[ ]:
# Names used when creating the 3LC Run and Tables further down in the notebook.
PROJECT_NAME = "3LC Tutorials - PyTorch Lightning Classification"
RUN_NAME = "Train classifier"
RUN_DESCRIPTION = "Train a resnet model on CIFAR-10"
TRAIN_DATASET_NAME = "cifar-10-train"
VAL_DATASET_NAME = "cifar-10-val"
# Root directory handed to torchvision when downloading CIFAR-10.
DOWNLOAD_PATH = "../../transient_data"
# Training configuration.
EPOCHS = 5  # passed to the Trainer as `max_epochs`
BATCH_SIZE = 32  # default `batch_size` of the LightningModule
NUM_WORKERS = 0  # passed as `num_workers` to the DataLoaders
# Set to False to skip the pip installs in the "Install dependencies" cell.
INSTALL_DEPENDENCIES = True

Install dependencies

[ ]:
# Notebook-only cell: `%pip` is an IPython magic, so this is not plain Python.
# The installs are skipped when INSTALL_DEPENDENCIES is False.
if INSTALL_DEPENDENCIES:
    %pip install -q 3lc[pacmap]
    %pip install -q pytorch-lightning
    %pip install -q git+https://github.com/3lc-ai/3lc-examples.git

Imports

[ ]:
import pytorch_lightning as pl
import tlc
import torch
import torch.nn.functional as F
import torchvision
from tlc.integration.torch.samplers import create_sampler
from torch.utils.data import DataLoader

from tlc_tools.common import infer_torch_device

Define model creation function

[ ]:
# Create model for cifar10 training
def create_model():
    """Build an untrained ResNet-18 whose output layer has 10 classes (CIFAR-10)."""
    # `weights=None` is the current torchvision spelling of the deprecated
    # `pretrained=False`: no pretrained weights are loaded.
    return torchvision.models.resnet18(weights=None, num_classes=10)

Define the schema of our dataset

[ ]:
################## 3LC ##################

# Schema describes the columns 3LC will write into the Table. We include an explicit
# weight column (`SampleWeightSchema`) so the 3LC sampler has weights to sample from;
# `TableWriter` doesn't add one automatically the way `Table.from_torch_dataset` did.
#
# `classes` doubles as the value map for both the label column here and the
# "predicted" metrics column defined later; the list position is the integer label
# (assumed to match the label order of torchvision's CIFAR-10 dataset).
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
schema = {
    # One entry per column written by `write_cifar_table`.
    "image": tlc.schemas.ImageSchema(),
    "label": tlc.schemas.CategoricalLabelSchema(classes=classes),
    "weight": tlc.schemas.SampleWeightSchema(),
}

#########################################

Describe the metrics we want to collect

[ ]:
################## 3LC ##################

# Define a function for the metrics we want to collect
def metrics_fn(batch, predictor_output: "tlc.metrics.PredictorOutput"):
    """Compute per-sample classification metrics for one batch.

    Args:
        batch: A batch as produced by the dataloader; the labels are read from
            ``batch[1]``.
        predictor_output: Output of the 3LC predictor. Its ``forward`` attribute
            is used as the raw class scores (logits), one row per sample.

    Returns:
        A dict mapping each metric name (``loss``, ``predicted``, ``accuracy``,
        ``confidence``, ``rmse``, ``mae``) to a NumPy array with one entry per
        sample in the batch.
    """
    # Detach so the NumPy conversion below also works if the forward pass was
    # run with autograd enabled.
    predictions = predictor_output.forward.detach()
    # Keep the labels on the same device as the predictions. Guessing the device
    # independently can disagree with where the model actually ran and make the
    # comparisons / losses below fail with a device-mismatch error.
    labels = batch[1].to(predictions.device)
    num_classes = predictions.shape[1]
    one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()

    # Confidence & Predicted: probability assigned to the arg-max class.
    softmax_output = F.softmax(predictions, dim=1)
    predicted_indices = torch.argmax(predictions, dim=1)
    confidence = torch.gather(softmax_output, 1, predicted_indices.unsqueeze(1)).squeeze(1)

    # Per-sample accuracy (1 if correct, 0 otherwise)
    accuracy = (predicted_indices == labels).float()

    # Unreduced Cross Entropy Loss (one value per sample)
    cross_entropy_loss = F.cross_entropy(predictions, labels, reduction="none")

    # RMSE / MAE between the softmax distribution and the one-hot target,
    # averaged over the class dimension.
    error = softmax_output - one_hot_labels
    rmse = torch.sqrt(error.pow(2).mean(dim=1))
    mae = error.abs().mean(dim=1)

    # These values will be the columns of the Run in the 3LC Dashboard
    return {
        "loss": cross_entropy_loss.cpu().numpy(),
        "predicted": predicted_indices.cpu().numpy(),
        "accuracy": accuracy.cpu().numpy(),
        "confidence": confidence.cpu().numpy(),
        "rmse": rmse.cpu().numpy(),
        "mae": mae.cpu().numpy(),
    }


# Schemas will be inferred automatically, but can be explicitly defined if customizations are needed,
# for example to set a description or a value map for an integer label.
# Only "loss" and "predicted" are customized here; the remaining keys returned by
# `metrics_fn` ("accuracy", "confidence", "rmse", "mae") have no explicit schema.
schemas = {
    "loss": tlc.schemas.Float32Schema(description="Cross entropy loss"),
    # Reuses `classes` so predicted indices are shown with the same names as the labels.
    "predicted": tlc.schemas.CategoricalLabelSchema(
        display_name="predicted label",
        classes=classes,
    ),
}

# Use the metrics function and schemas to create a metrics collector
# (passed to `tlc.collect_metrics` at the end of training).
classification_metrics_collector = tlc.metrics.FunctionalMetricsCollector(
    collection_fn=metrics_fn,
    schema=schemas,
)

#########################################

Create 3LC Tables

We create the Tables eagerly, outside the LightningModule. This sidesteps the DDP table-coordination problem that arises when tables are created inside train_dataloader (which Lightning replicates per process). With tables on disk before Trainer.fit(), every process simply opens the same Table.

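We use tlc.TableWriter directly — iterating the torchvision dataset and pushing rows in batches. Transforms are attached at sample-time via Table.with_transform, so the Table itself preserves the original PIL images for visualization. When collecting metrics on the training table we use the validation transform, so that metrics are never computed on augmented images. In this notebook the training and validation transforms are identical (neither applies augmentation), but the distinction matters as soon as augmentation is added to the training transform.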

[ ]:
# Transform applied to training images: `ToTensor` followed by a per-channel
# `Normalize` with 0.5 for every mean and std entry. No augmentation is applied.
train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Transform applied to validation images and during metrics collection.
# Currently identical to `train_transform`; kept separate so the training
# pipeline can be changed without affecting evaluation.
val_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def train_fn(sample):
    """Turn one Table row into an ``(image, label)`` pair using the training transform."""
    transformed_image = train_transform(sample["image"])
    target = sample["label"]
    return transformed_image, target


def val_fn(sample):
    """Turn one Table row into an ``(image, label)`` pair using the validation transform."""
    transformed_image = val_transform(sample["image"])
    target = sample["label"]
    return transformed_image, target


def write_cifar_table(dataset, dataset_name, batch_size=1000):
    """Stream a torchvision CIFAR-10 split into a 3LC Table via TableWriter.

    Args:
        dataset: Iterable yielding ``(image, label)`` pairs, e.g. a
            ``torchvision.datasets.CIFAR10`` instance.
        dataset_name: Name of the 3LC dataset the Table is written under
            (inside ``PROJECT_NAME``).
        batch_size: Number of rows buffered in memory before they are handed
            to the writer. Defaults to 1000.

    Returns:
        The value returned by ``writer.finalize()``, which the rest of the
        notebook uses as the Table.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    writer = tlc.TableWriter(
        project_name=PROJECT_NAME,
        dataset_name=dataset_name,
        schema=schema,
        if_exists="overwrite",
    )

    def write_rows(batch_images, batch_labels):
        # Every row gets the same initial sample weight of 1.0.
        writer.add_batch(
            {
                "image": batch_images,
                "label": batch_labels,
                "weight": [1.0] * len(batch_images),
            }
        )

    images, labels = [], []
    for image, label in dataset:
        images.append(image)
        labels.append(label)
        if len(images) >= batch_size:
            write_rows(images, labels)
            # Start fresh lists rather than clearing in place, so lists already
            # handed to the writer are never mutated afterwards.
            images, labels = [], []

    # Write the final, partially filled batch (nothing to do for an empty tail).
    if images:
        write_rows(images, labels)
    return writer.finalize()


# Download CIFAR-10 (if not already present under DOWNLOAD_PATH). No transform is
# passed, so the datasets yield the untransformed images that get written to the Tables.
raw_train_dataset = torchvision.datasets.CIFAR10(root=DOWNLOAD_PATH, train=True, download=True)
raw_val_dataset = torchvision.datasets.CIFAR10(root=DOWNLOAD_PATH, train=False, download=True)

# Write one 3LC Table per split, before the LightningModule is created.
train_table = write_cifar_table(raw_train_dataset, TRAIN_DATASET_NAME)
val_table = write_cifar_table(raw_val_dataset, VAL_DATASET_NAME)

Define our LightningModule

The 3LC integration is just a few standard Lightning hooks:

  • on_train_start initializes a Run and records hyperparameters.

  • on_train_end collects per-sample metrics on the train and val tables, then marks the Run completed.

The dataloaders are built directly from the Tables, with a 3LC weighted sampler on the training side via tlc.integration.torch.samplers.create_sampler.

[ ]:
class MyModule(pl.LightningModule):
    """CIFAR-10 classifier that reports to 3LC from standard Lightning hooks.

    The module receives already-created 3LC Tables, builds its dataloaders from
    them, starts a 3LC Run in ``on_train_start`` and collects per-sample metrics
    on both Tables in ``on_train_end``.
    """

    def __init__(self, train_table, val_table, batch_size=BATCH_SIZE, lr=1e-3):
        """Store the Tables and hyperparameters and build the network.

        Args:
            train_table: 3LC Table used for training.
            val_table: 3LC Table used for validation.
            batch_size: Batch size of both dataloaders.
            lr: Learning rate of the Adam optimizer.
        """
        super().__init__()
        # The Tables are excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=["train_table", "val_table"])
        # Set in `on_train_start`; stays None until training begins.
        self.tlc_run: tlc.Run | None = None
        self.model = create_model()
        self.train_table, self.val_table = train_table, val_table
        self.batch_size, self.lr = batch_size, lr

    def forward(self, x):
        """Return the network output for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        inputs, targets = batch
        return F.cross_entropy(self(inputs), targets)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Build the training DataLoader, sampling rows with the 3LC weighted sampler."""
        dataset = self.train_table.with_transform(train_fn)
        weighted_sampler = create_sampler(self.train_table, weighted=True, exclude_zero_weights=True)
        return DataLoader(
            dataset,
            sampler=weighted_sampler,
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
        )

    def val_dataloader(self):
        """Build the validation DataLoader over the validation Table."""
        dataset = self.val_table.with_transform(val_fn)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
        )

    def on_train_start(self):
        """Create the 3LC Run, record the hyperparameters and mark it as running."""
        super().on_train_start()
        run = tlc.init(
            project_name=PROJECT_NAME,
            run_name=RUN_NAME,
            description=RUN_DESCRIPTION,
            parameters=dict(self.hparams_initial),
            if_exists="rename",
        )
        self.tlc_run = run
        run.set_status_running()

    def on_train_end(self):
        """Collect metrics once training has finished, then mark the Run completed."""
        super().on_train_end()
        # Metrics are collected a single time, at the end of training.
        self._collect_3lc_metrics()
        if self.tlc_run is None:
            return
        self.tlc_run.set_status_completed()

    def _collect_3lc_metrics(self):
        """Run 3LC metrics collection on the train and val Tables."""
        predictor = tlc.metrics.Predictor(self)
        tables_by_split = {"train": self.train_table, "val": self.val_table}
        # Both splits use the val transform, so metrics aren't computed on augmented images.
        for split_name, split_table in tables_by_split.items():
            tlc.collect_metrics(
                table=split_table.with_transform(val_fn),
                metrics_collectors=[classification_metrics_collector],
                predictor=predictor,
                split=split_name,
                constants={"epoch": self.current_epoch},
                exclude_zero_weights=True,
            )

Run training

[ ]:
# Create the LightningModule, passing in the Tables we created above.
module = MyModule(train_table=train_table, val_table=val_table)

# Train the model for EPOCHS epochs. No dataloaders are passed to `fit`;
# the module supplies them through its own `train_dataloader` / `val_dataloader`.
trainer = pl.Trainer(max_epochs=EPOCHS)
trainer.fit(module)

After training has completed, the Run can be viewed in the 3LC Dashboard.