PyTorch 3LC MNIST Sample Notebook

This notebook demonstrates training a Convolutional Neural Network (CNN) on the MNIST dataset using PyTorch and 3LC. The built-in torchvision MNIST training and test splits (the test split serves as the validation set) are each wrapped in a 3LC Table. Training runs for 5 epochs; after each epoch, per-sample classification metrics (loss, predicted label, confidence, accuracy) and embeddings are collected on both splits using a custom metrics collector and an embeddings metrics collector.

The notebook demonstrates:

  • How to use a 3LC Table for integrating with built-in PyTorch datasets.

  • Metrics collection using a custom tlc.MetricsCollector subclass and EmbeddingsMetricsCollector.

  • Reducing the dimensionality of embeddings using UMAP as a post-processing step.

[2]:
# Parameters
PROJECT_NAME = "MNIST Digit Classification"
RUN_NAME = "Train MNIST Classifier"
DESCRIPTION = "Train a simple CNN to classify MNIST digits"
TRAIN_DATASET_NAME = "mnist-train"
VAL_DATASET_NAME = "mnist-val"
TRANSIENT_DATA_PATH = "../transient_data"
COLLECT_METRICS_BATCH_SIZE = 2048
TRAIN_BATCH_SIZE = 64
INITIAL_LR = 1.0
LR_GAMMA = 0.7
EPOCHS = 5
NUM_WORKERS = 0
DEVICE = "cuda:0"
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
INSTALL_DEPENDENCIES = False
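To get a feel for what these parameters mean per epoch, the arithmetic below counts how many batches each loop yields. It is a sketch that assumes the standard MNIST split sizes (60,000 training and 10,000 test images), DataLoader's default `drop_last=False`, and that the 3LC weighted sampler draws one sample per Table row per epoch; that last point is an assumption, not something this notebook states.

```python
import math

# Standard MNIST split sizes (torchvision: 60,000 train / 10,000 test images).
TRAIN_SIZE = 60_000
VAL_SIZE = 10_000

# Values copied from the Parameters cell above.
TRAIN_BATCH_SIZE = 64
COLLECT_METRICS_BATCH_SIZE = 2048


def num_batches(num_samples: int, batch_size: int) -> int:
    """Number of batches a DataLoader yields with drop_last=False (the default)."""
    return math.ceil(num_samples / batch_size)


# Assumption: the sampler draws len(table) samples per epoch.
train_steps = num_batches(TRAIN_SIZE, TRAIN_BATCH_SIZE)
train_metric_batches = num_batches(TRAIN_SIZE, COLLECT_METRICS_BATCH_SIZE)
val_metric_batches = num_batches(VAL_SIZE, COLLECT_METRICS_BATCH_SIZE)
print(train_steps, train_metric_batches, val_metric_batches)  # 938 30 5
```

The large `COLLECT_METRICS_BATCH_SIZE` keeps metrics collection to a handful of forward passes per split, since no gradients are needed there.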
[4]:
%%capture
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install tlc[umap]
[7]:
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm.auto import tqdm

import tlc

Initialize a 3LC Run

[9]:
config = {
    "train_batch_size": TRAIN_BATCH_SIZE,
    "initial_lr": INITIAL_LR,
    "lr_gamma": LR_GAMMA,
    "epochs": EPOCHS,
}

run = tlc.init(
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
    description=DESCRIPTION,
    parameters=config,
    if_exists="overwrite",
)

Set Up Datasets

[10]:
# Convert PIL images to [0, 1] tensors, then normalize with the MNIST mean and std
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = torchvision.datasets.MNIST(root=TRANSIENT_DATA_PATH, train=True, download=True)
eval_dataset = torchvision.datasets.MNIST(root=TRANSIENT_DATA_PATH, train=False)

[11]:
class_names = [str(i) for i in range(10)]

structure = (tlc.PILImage("image"), tlc.CategoricalLabel("label", class_names))

def transforms(x):
    return transform(x[0]), torch.tensor(x[1])

# We pick up the latest version of the dataset, so that we can re-run this notebook as-is
# after adding new revisions to the dataset.
tlc_train_dataset = tlc.Table.from_torch_dataset(
    dataset=train_dataset,
    dataset_name=TRAIN_DATASET_NAME,
    structure=structure,
    project_name=PROJECT_NAME,
    table_name="train",
).map(transforms).latest()

tlc_val_dataset = tlc.Table.from_torch_dataset(
    dataset=eval_dataset,
    dataset_name=VAL_DATASET_NAME,
    structure=structure,
    project_name=PROJECT_NAME,
    table_name="val"
).map(transforms).latest()
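Each row read from these Tables passes through `transforms`, which applies `ToTensor` (scaling uint8 pixels from [0, 255] to [0.0, 1.0]) followed by `Normalize` with mean 0.1307 and std 0.3081, and turns the integer label into a tensor. The pure-Python sketch below reproduces that per-pixel arithmetic on a plain list of pixel values; `transform_sample` is a hypothetical stand-in that skips the PIL image and tensor types.

```python
from __future__ import annotations

# Normalization constants from the transform cell (MNIST mean and std).
MEAN = 0.1307
STD = 0.3081


def to_tensor_value(pixel: int) -> float:
    """ToTensor maps a uint8 pixel in [0, 255] to a float in [0.0, 1.0]."""
    return pixel / 255.0


def normalize(value: float) -> float:
    """Normalize subtracts the mean and divides by the standard deviation."""
    return (value - MEAN) / STD


def transform_sample(pixels: list[int], label: int) -> tuple[list[float], int]:
    """Hypothetical stand-in for `transforms`: normalize pixels, pass the label through."""
    return [normalize(to_tensor_value(p)) for p in pixels], label


pixels, label = transform_sample([0, 255], 7)
print([round(v, 4) for v in pixels], label)  # [-0.4242, 2.8215] 7
```

So a black background pixel becomes roughly -0.42 and a white stroke pixel roughly 2.82, which centers the input distribution near zero.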

Set Up Model

[12]:
device = torch.device(DEVICE)
print(f"Using device: {device}")
Using device: cuda
[13]:
class Net(nn.Module):
    # From https://github.com/pytorch/examples/blob/main/mnist/main.py
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

model = Net().to(device)
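The input size of `fc1` (9216) follows from the spatial arithmetic of the layers before it. The short calculation below traces a 28x28 MNIST image through `conv1`, `conv2` (3x3 kernels, stride 1, no padding) and `max_pool2d(x, 2)`, using the standard output-size formulas for these operations.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def max_pool2d_out(size: int, kernel: int) -> int:
    """Spatial output size of max_pool2d when stride defaults to the kernel size."""
    return (size - kernel) // kernel + 1


side = 28                          # MNIST images are 28x28 with 1 channel
side = conv2d_out(side, kernel=3)  # conv1: 1 -> 32 channels, 28 -> 26
side = conv2d_out(side, kernel=3)  # conv2: 32 -> 64 channels, 26 -> 24
side = max_pool2d_out(side, 2)     # max_pool2d(2): 24 -> 12
flat_features = 64 * side * side   # torch.flatten(x, 1) keeps channels x height x width
print(side, flat_features)  # 12 9216
```

If the input resolution or kernel sizes change, `fc1` must be resized to match the new flattened size.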

Set Up Training Loop

[14]:
optimizer = torch.optim.Adadelta(model.parameters(), lr=INITIAL_LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=LR_GAMMA)
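With `step_size=1`, `StepLR` multiplies the learning rate by `LR_GAMMA` every time `scheduler.step()` is called; it never decays on its own. The sketch below computes the resulting schedule, assuming `scheduler.step()` is called exactly once at the end of each epoch.

```python
# Values from the Parameters cell.
INITIAL_LR = 1.0
LR_GAMMA = 0.7
EPOCHS = 5


def step_lr(epoch: int, initial_lr: float, gamma: float, step_size: int = 1) -> float:
    """Learning rate in effect during `epoch` (0-indexed), assuming one scheduler.step() per epoch."""
    return initial_lr * gamma ** (epoch // step_size)


schedule = [step_lr(e, INITIAL_LR, LR_GAMMA) for e in range(EPOCHS)]
print([round(lr, 4) for lr in schedule])  # [1.0, 0.7, 0.49, 0.343, 0.2401]
```

The relatively high starting rate of 1.0 is typical for Adadelta, whose `lr` acts as a scaling factor on an already adaptive update.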
[15]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for data, target in tqdm(train_loader, desc=f"Training {epoch + 1}/{EPOCHS}"):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        loss.backward()
        optimizer.step()
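Because `Net.forward` already returns `F.log_softmax(x, dim=1)`, the training loss is `F.nll_loss`, which picks the log-probability of the target class, negates it, and averages over the batch. Together they compute the usual cross-entropy on the raw logits. The pure-Python sketch below spells out both steps on a toy batch; the logits and targets are made-up values for illustration.

```python
import math


def log_softmax(logits: list) -> list:
    """Numerically stable log-softmax over one row, as F.log_softmax(x, dim=1) does per sample."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs: list, targets: list) -> float:
    """Mean negative log-likelihood of the target class (F.nll_loss's default reduction)."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


# Toy batch: 2 samples, 3 classes (illustrative values only).
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
loss = nll_loss([log_softmax(row) for row in logits], targets)
print(round(loss, 4))
```

Subtracting the row maximum before exponentiating avoids overflow for large logits without changing the result.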

Set Up Metrics Collectors

[16]:
class MNISTMetricsCollector(tlc.MetricsCollector):
    def __init__(self, model, criterion):
        super().__init__()
        self._model = model
        self._criterion = criterion

    def compute_metrics(self, batch, predictions, _):
        labels = batch[1].to(device)

        metrics = {
            "loss": self._criterion(predictions, labels).cpu().numpy(),
            "predicted": torch.argmax(predictions, dim=1).cpu().numpy(),
            "confidence": torch.exp(torch.max(predictions, dim=1).values).cpu().numpy(),
            "accuracy": (torch.argmax(predictions, dim=1) == labels).cpu().numpy(),
        }
        return metrics

    @property
    def column_schemas(self):
        # Explicitly override the schema of the predicted label, in order for it to be displayed as a
        # categorical label in the Dashboard.
        schemas = {
            "predicted": tlc.CategoricalLabelSchema(
                class_names,
                display_name="predicted label",
            )
        }
        return schemas


mnist_metrics_collector = MNISTMetricsCollector(model, nn.NLLLoss(reduction="none"))
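`MNISTMetricsCollector.compute_metrics` returns one value per sample rather than a batch average, which is why the criterion is built with `reduction="none"`. The sketch below mirrors its four columns in pure Python for a batch of log-probabilities; `per_sample_metrics` is a hypothetical helper for illustration, not part of the 3LC API.

```python
from __future__ import annotations

import math


def per_sample_metrics(log_probs: list[list[float]], labels: list[int]) -> dict[str, list]:
    """Hypothetical pure-Python mirror of MNISTMetricsCollector.compute_metrics."""
    metrics = {"loss": [], "predicted": [], "confidence": [], "accuracy": []}
    for row, label in zip(log_probs, labels):
        predicted = max(range(len(row)), key=row.__getitem__)   # argmax over classes
        metrics["loss"].append(-row[label])                     # NLLLoss(reduction="none")
        metrics["predicted"].append(predicted)
        metrics["confidence"].append(math.exp(row[predicted]))  # exp(max log-probability)
        metrics["accuracy"].append(predicted == label)
    return metrics


# Two samples, 3 classes; rows are log-probabilities, as Net.forward returns.
log_probs = [[math.log(0.7), math.log(0.2), math.log(0.1)],
             [math.log(0.1), math.log(0.6), math.log(0.3)]]
result = per_sample_metrics(log_probs, labels=[0, 2])
print(result["predicted"], result["accuracy"])  # [0, 1] [True, False]
```

Exponentiating the maximum log-probability recovers the softmax probability of the predicted class, so `confidence` lies in (0, 1].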
[17]:

Run Training

We run training using a weighted sampler provided by the 3LC Table. The sampler draws rows according to the Table's default weights column. Weights edited in the Dashboard produce a new revision of the Table; because the Tables are loaded with `.latest()`, re-running the notebook picks up the new weights automatically.
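To illustrate the effect of a weights column, the sketch below draws row indices with replacement in proportion to per-row weights, in the spirit of `torch.utils.data.WeightedRandomSampler`. It assumes the 3LC sampler behaves similarly; its actual implementation may differ. The weights and the helper `weighted_sample_indices` are hypothetical.

```python
import random
from collections import Counter


def weighted_sample_indices(weights: list, num_samples: int, seed: int = 0) -> list:
    """Hypothetical sketch: draw row indices with replacement, proportional to each row's weight."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


# Hypothetical weights column: row 2 is excluded (weight 0), row 3 is up-weighted 4x.
weights = [1.0, 1.0, 0.0, 4.0]
indices = weighted_sample_indices(weights, num_samples=6000)
counts = Counter(indices)
print(sorted(counts.items()))  # row 2 never appears; row 3 gets about 4/6 of the draws
```

Setting a row's weight to zero effectively removes it from training without deleting it from the Table, while larger weights let hard or rare examples be revisited more often.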

[18]:
sampler = tlc_train_dataset.create_sampler()

train_loader = torch.utils.data.DataLoader(
    tlc_train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
)

metrics_collection_dataloader_args = {
    "num_workers": NUM_WORKERS,
    "batch_size": COLLECT_METRICS_BATCH_SIZE,
}

# Train the model, stepping the learning-rate scheduler once per epoch
for epoch in range(EPOCHS):
    train(model, device, train_loader, optimizer, epoch)
    scheduler.step()

    tlc.collect_metrics(
        tlc_train_dataset,
        metrics_collectors=[
            mnist_metrics_collector,
            embeddings_metrics_collector,
        ],
        split="train",
        constants={"epoch": epoch},
        dataloader_args=metrics_collection_dataloader_args,
    )
    tlc.collect_metrics(
        tlc_val_dataset,
        metrics_collectors=[
            mnist_metrics_collector,
            embeddings_metrics_collector,
        ],
        split="val",
        constants={"epoch": epoch},
        dataloader_args=metrics_collection_dataloader_args,
    )
[19]:
# Reduce embeddings using the final validation-set embeddings to fit a UMAP model
url_mapping = run.reduce_embeddings_by_foreign_table_url(
    tlc_val_dataset.url,
    n_components=3,
)