View source Download .ipynb

Training a classifier using PyTorch Lightning

This notebook trains a classifier on CIFAR-10 using PyTorch Lightning.

image1

When a LightningModule defines the `train_dataloader`, `val_dataloader` and/or `test_dataloader` methods, we can decorate it with `tlc.lightning_module` to automatically generate Tables for our datasets and collect any desired metrics into a Run.

Project setup

[ ]:
# 3LC project under which the Tables and the Run for this tutorial are created
PROJECT_NAME = "3LC Tutorials - PyTorch Lightning Classification"
# Directory that CIFAR-10 is downloaded to and read from
DOWNLOAD_PATH = "../../transient_data"
# Training length, DataLoader batch size, and DataLoader worker processes
EPOCHS, BATCH_SIZE, NUM_WORKERS = 5, 32, 0

Install dependencies

[ ]:
# Install 3LC (with its optional "pacmap" extra), PyTorch Lightning, and the
# 3lc-examples package from GitHub (presumably the source of the `tlc_tools`
# module imported in the next cell).
%pip install 3lc[pacmap]
%pip install pytorch-lightning
%pip install git+https://github.com/3lc-ai/3lc-examples.git

Imports

[ ]:
import pytorch_lightning as pl
import tlc
import torch
import torch.nn.functional as F
import torchvision

from tlc_tools.common import infer_torch_device

Define model creation function

[ ]:
# Create model for cifar10 training
def create_model(num_classes=10):
    """Return a randomly initialised torchvision ResNet-18 classifier.

    Args:
        num_classes: Number of outputs of the final classification layer.
            Defaults to 10, the number of CIFAR-10 classes.
    """
    # ``weights=None`` replaces the deprecated ``pretrained=False`` argument and
    # likewise skips loading any pretrained weights.
    return torchvision.models.resnet18(weights=None, num_classes=num_classes)

Define the structure of our dataset

[ ]:
################## 3LC ##################

# Describe what one dataset sample looks like: an (image, label) tuple.
# The first element is a PIL image, exposed as the "Image" column; the second is an
# integer class index, exposed as the categorical "Label" column whose values map
# onto the names in `classes`.
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
structure = (
    tlc.PILImage("Image"),
    tlc.CategoricalLabel("Label", classes=classes),
)

#########################################

Describe the metrics we want to collect

[ ]:
################## 3LC ##################

# Define a function for the metrics we want to collect
def metrics_fn(batch, predictor_output: "tlc.PredictorOutput"):
    """Compute per-sample classification metrics for one batch.

    Args:
        batch: An ``(images, labels)`` batch from the DataLoader; only ``batch[1]``
            (the integer class labels) is used.
        predictor_output: 3LC predictor output for the batch. Its ``forward``
            attribute holds the model's unnormalized logits, shaped
            ``(batch_size, num_classes)``.

    Returns:
        A dict mapping column name to a 1-D NumPy array with one value per sample:
        ``loss`` (cross entropy), ``predicted`` (argmax class index), ``accuracy``
        (1.0 if correct, else 0.0), ``confidence`` (softmax probability of the
        predicted class), and ``rmse`` / ``mae`` between the softmax output and the
        one-hot encoded label. These become the columns of the Run in the 3LC Dashboard.
    """
    predictions = predictor_output.forward
    # Put the labels on whatever device the predictions are on, rather than on a
    # separately inferred device, so the two tensors can never end up mismatched.
    labels = batch[1].to(predictions.device)
    num_classes = predictions.shape[1]
    one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()

    # Confidence & predicted class
    softmax_output = F.softmax(predictions, dim=1)
    predicted_indices = torch.argmax(predictions, dim=1)
    confidence = torch.gather(softmax_output, 1, predicted_indices.unsqueeze(1)).squeeze(1)

    # Per-sample accuracy (1 if correct, 0 otherwise)
    accuracy = (predicted_indices == labels).float()

    # Unreduced cross-entropy loss
    cross_entropy_loss = F.cross_entropy(predictions, labels, reduction="none")

    # RMSE between the softmax distribution and the one-hot target
    rmse = torch.sqrt(F.mse_loss(softmax_output, one_hot_labels, reduction="none").mean(dim=1))

    # MAE between the softmax distribution and the one-hot target
    mae = F.l1_loss(softmax_output, one_hot_labels, reduction="none").mean(dim=1)

    def to_numpy(tensor):
        # detach() so the conversion also works when autograd is tracking the
        # predictions; .numpy() raises on tensors that require grad.
        return tensor.detach().cpu().numpy()

    # These values will be the columns of the Run in the 3LC Dashboard
    return {
        "loss": to_numpy(cross_entropy_loss),
        "predicted": to_numpy(predicted_indices),
        "accuracy": to_numpy(accuracy),
        "confidence": to_numpy(confidence),
        "rmse": to_numpy(rmse),
        "mae": to_numpy(mae),
    }


# Column schemas are inferred from the values metrics_fn returns. Declaring some of them
# explicitly lets us customize those columns, e.g. give one a description, or map the
# integer predictions onto human-readable class names.
schemas = dict(
    loss=tlc.Schema(
        value=tlc.Float32Value(),
        description="Cross entropy loss",
    ),
    predicted=tlc.CategoricalLabelSchema(
        class_names=classes,
        display_name="predicted label",
    ),
)

# Bundle the metrics function and its schemas into the collector that the 3LC
# LightningModule decorator is given below.
classification_metrics_collector = tlc.FunctionalMetricsCollector(
    column_schemas=schemas,
    collection_fn=metrics_fn,
)

#########################################

Define our LightningModule (with 3LC decorator)

[ ]:
################## 3LC ##################
@tlc.lightning_module(
    structure=structure,
    project_name=PROJECT_NAME,
    metrics_collectors=classification_metrics_collector,
)
#########################################
class MyModule(pl.LightningModule):
    """LightningModule that trains the model from ``create_model`` on CIFAR-10.

    The 3LC decorator above is configured with the dataset ``structure`` and the
    classification metrics collector; per the tutorial text it uses the
    module's dataloader methods to create Tables and collect metrics into a Run.

    NOTE(review): no ``validation_step`` is defined, so Lightning's own validation
    loop will not run; presumably ``val_dataloader`` is consumed by the 3LC
    decorator for metrics collection — confirm before adding one.
    """

    def __init__(self, batch_size=BATCH_SIZE, lr=1e-3):
        """
        Args:
            batch_size: Batch size for both the training and validation DataLoaders.
            lr: Learning rate for the Adam optimizer.
        """
        super().__init__()
        # Record the constructor arguments in ``self.hparams``.
        self.save_hyperparameters()
        self.model = create_model()
        self.batch_size = batch_size
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Surface the mean batch loss in the progress bar and the trainer's logger.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @staticmethod
    def _cifar10_transform():
        """Build the preprocessing pipeline shared by the training and validation splits."""
        return torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                # Shift and scale each channel: (x - 0.5) / 0.5.
                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def _cifar10_loader(self, train, shuffle):
        """Create a DataLoader over the CIFAR-10 train split (``train=True``) or test split."""
        dataset = torchvision.datasets.CIFAR10(
            root=DOWNLOAD_PATH,
            train=train,
            download=True,
            transform=self._cifar10_transform(),
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=NUM_WORKERS
        )

    def train_dataloader(self):
        # Training split, reshuffled every epoch.
        return self._cifar10_loader(train=True, shuffle=True)

    def val_dataloader(self):
        # The CIFAR-10 test split serves as the validation set, in a fixed order.
        return self._cifar10_loader(train=False, shuffle=False)

Run training

[ ]:
# Instantiate the 3LC-decorated LightningModule defined above
module = MyModule()

# Build a Trainer that runs for EPOCHS epochs and fit the module with it
trainer = pl.Trainer(max_epochs=EPOCHS)
trainer.fit(model=module)

After training has completed, the Run can be viewed in the 3LC Dashboard.