PyTorch 3LC MNIST Sample Notebook#

This notebook trains a Convolutional Neural Network (CNN) on the MNIST dataset using PyTorch and 3LC. The built-in torchvision MNIST training and validation sets are each wrapped in a 3LC Table. Training runs for 5 epochs; after each epoch, classification metrics and embeddings are collected on both tables.

The notebook demonstrates:

  • Wrapping built-in PyTorch datasets in a 3LC Table.

  • Collecting metrics with a custom MetricsCollector subclass and an EmbeddingsMetricsCollector.

  • Reducing the dimensionality of embeddings using PaCMAP as a post-processing step.

Project Setup#

[2]:
# Parameters
PROJECT_NAME = "MNIST Digit Classification"
RUN_NAME = "Train MNIST Classifier"
DESCRIPTION = "Train a simple CNN to classify MNIST digits"
TRAIN_DATASET_NAME = "mnist-train"
VAL_DATASET_NAME = "mnist-val"
TRANSIENT_DATA_PATH = "../transient_data"
COLLECT_METRICS_BATCH_SIZE = 2048
TRAIN_BATCH_SIZE = 64
INITIAL_LR = 1.0
LR_GAMMA = 0.7
EPOCHS = 5
NUM_WORKERS = 0
DEVICE = None
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
INSTALL_DEPENDENCIES = False
[4]:
%%capture
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install 3lc[pacmap]

Imports#

[7]:
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm.auto import tqdm

import tlc
[8]:
if DEVICE is None:
    if torch.cuda.is_available():
        device = "cuda:0"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
else:
    device = DEVICE

device = torch.device(device)
print(f"Using device: {device}")
Using device: cuda

Initialize a 3LC Run#

[9]:
config = {
    "train_batch_size": TRAIN_BATCH_SIZE,
    "initial_lr": INITIAL_LR,
    "lr_gamma": LR_GAMMA,
    "epochs": EPOCHS,
}

run = tlc.init(
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
    description=DESCRIPTION,
    parameters=config,
    if_exists="overwrite",
)

Setup Datasets#

[10]:
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))]
)

train_dataset = torchvision.datasets.MNIST(root=TRANSIENT_DATA_PATH, train=True, download=True)
eval_dataset = torchvision.datasets.MNIST(root=TRANSIENT_DATA_PATH, train=False)
[11]:
class_names = [str(i) for i in range(10)]

structure = (tlc.PILImage("image"), tlc.CategoricalLabel("label", class_names))


def transforms(x):
    return transform(x[0]), torch.tensor(x[1])


# We pick up the latest version of the dataset, so that we can re-run this notebook as-is
# after adding new revisions to the dataset.
tlc_train_dataset = (
    tlc.Table.from_torch_dataset(
        dataset=train_dataset,
        dataset_name=TRAIN_DATASET_NAME,
        structure=structure,
        project_name=PROJECT_NAME,
        description="MNIST training dataset",
        table_name="train",
        if_exists="overwrite",
    )
    .map(transforms)
    .latest()
)

tlc_val_dataset = (
    tlc.Table.from_torch_dataset(
        dataset=eval_dataset,
        dataset_name=VAL_DATASET_NAME,
        structure=structure,
        project_name=PROJECT_NAME,
        description="MNIST validation dataset",
        table_name="val",
        if_exists="overwrite",
    )
    .map(transforms)
    .latest()
)
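
As a rough illustration of what the `transform` applied through `.map(transforms)` does to a single pixel: `ToTensor` scales 8-bit values to the range [0, 1], and `Normalize` then subtracts the mean and divides by the standard deviation. The sketch below reproduces that arithmetic in plain Python. The helper name `normalize_pixel` is hypothetical and is not part of torchvision or 3LC.

```python
MNIST_MEAN = 0.1307  # mean passed to Normalize in this notebook
MNIST_STD = 0.3081   # std passed to Normalize in this notebook


def normalize_pixel(value: int) -> float:
    """Map an 8-bit pixel value to the value the model sees (hypothetical helper)."""
    scaled = value / 255.0                    # what ToTensor does for uint8 images
    return (scaled - MNIST_MEAN) / MNIST_STD  # what Normalize does


black = normalize_pixel(0)    # background pixels
white = normalize_pixel(255)  # brightest stroke pixels
print(f"{black:.4f} {white:.4f}")
```

Background pixels end up slightly below zero and stroke pixels close to 2.8, so the inputs are roughly zero-centered over the whole dataset.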

Setup Model#

[12]:
class Net(nn.Module):
    # From https://github.com/pytorch/examples/blob/main/mnist/main.py
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


model = Net().to(device)
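
The input size of `fc1` (9216) follows from the shapes of the layers before it. The sketch below retraces that arithmetic for a 28x28 MNIST image, assuming no padding and no dilation, which matches the `nn.Conv2d` and `F.max_pool2d` calls above. The helper `conv_out` is a hypothetical name introduced only for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                  # MNIST images are 28x28
size = conv_out(size, kernel=3)            # conv1: 28 -> 26
size = conv_out(size, kernel=3)            # conv2: 26 -> 24
size = conv_out(size, kernel=2, stride=2)  # max_pool2d(x, 2): 24 -> 12

channels = 64                              # output channels of conv2
flattened = channels * size * size         # length after torch.flatten(x, 1)
print(flattened)
```

If the input resolution or any kernel size changes, the first argument of `nn.Linear(9216, 128)` has to change accordingly.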

Setup Training Loop#

[13]:
optimizer = torch.optim.Adadelta(model.parameters(), lr=INITIAL_LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=LR_GAMMA)
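
With `step_size=1`, `StepLR` multiplies the learning rate by `LR_GAMMA` every time `scheduler.step()` is called. The sketch below computes the resulting schedule in closed form, assuming the scheduler is stepped exactly once per epoch. The function `step_lr` is a hypothetical helper, not a PyTorch API.

```python
INITIAL_LR = 1.0
LR_GAMMA = 0.7
EPOCHS = 5


def step_lr(initial_lr: float, gamma: float, epoch: int, step_size: int = 1) -> float:
    """Closed-form StepLR value after `epoch` scheduler steps (hypothetical helper)."""
    return initial_lr * gamma ** (epoch // step_size)


# Learning rate used during each of the 5 epochs, rounded for display.
schedule = [round(step_lr(INITIAL_LR, LR_GAMMA, e), 4) for e in range(EPOCHS)]
print(schedule)
```

Under that assumption the last epoch trains at roughly a quarter of the initial learning rate. If the scheduler is never stepped, every epoch uses `INITIAL_LR`.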
[14]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for data, target in tqdm(train_loader, desc=f"Training {epoch+1}/{EPOCHS}"):  # Epoch is 0-indexed
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        loss.backward()
        optimizer.step()

Setup Metrics Collectors#

[15]:
class MNISTMetricsCollector(tlc.MetricsCollector):
    def __init__(self, criterion):
        super().__init__()
        self._criterion = criterion

    def compute_metrics(self, batch, predictor_output):
        predictions = predictor_output.forward
        labels = batch[1].to(device)

        metrics = {
            "loss": self._criterion(predictions, labels).cpu().numpy(),
            "predicted": torch.argmax(predictions, dim=1).cpu().numpy(),
            "confidence": torch.exp(torch.max(predictions, dim=1).values).cpu().numpy(),
            "accuracy": (torch.argmax(predictions, dim=1) == labels).cpu().numpy(),
        }
        return metrics

    @property
    def column_schemas(self):
        # Explicitly override the schema of the predicted label, in order for it to be displayed as a
        # categorical label in the Dashboard.
        schemas = {
            "predicted": tlc.CategoricalLabelSchema(
                class_names,
                display_name="predicted label",
            )
        }
        return schemas


mnist_metrics_collector = MNISTMetricsCollector(nn.NLLLoss(reduction="none"))
embeddings_metrics_collector = tlc.EmbeddingsMetricsCollector(layers=[4])
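
Because the model ends in `log_softmax`, its outputs are log-probabilities. That is why the collector exponentiates the maximum to obtain a confidence and uses `NLLLoss(reduction="none")` to obtain one loss value per sample. The sketch below works through the same four metrics for a single made-up sample in plain Python. The logits and the label are invented for illustration.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax for one sample."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_sum for v in logits]


logits = [2.0, 0.5, -1.0]  # made-up scores for a 3-class toy example
label = 1                  # made-up ground-truth class

log_probs = log_softmax(logits)
predicted = max(range(len(log_probs)), key=log_probs.__getitem__)  # argmax
confidence = math.exp(max(log_probs))  # probability of the predicted class
loss = -log_probs[label]               # per-sample negative log-likelihood
accuracy = predicted == label          # per-sample correctness flag

print(predicted, round(confidence, 3), round(loss, 3), accuracy)
```

Here the model is confidently wrong: the predicted class differs from the label, so the per-sample loss is high even though the confidence is high. Samples like this are the ones that per-sample metrics make easy to find.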

Run Training#

Training uses a weighted sampler provided by the 3LC Table. The sampler draws samples according to the Table's default weights column. Weights edited in the Dashboard are picked up automatically by the sampler.
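
To make the idea of weighted sampling concrete, here is a minimal sketch using only the Python standard library. It is not 3LC's implementation, and the details of `create_sampler()` (for example, whether it samples with replacement) may differ. The weights below are made up.

```python
import random

# Hypothetical per-sample weights, e.g. as they might look after editing.
weights = [1.0, 1.0, 0.0, 3.0]

rng = random.Random(0)  # fixed seed so the sketch is reproducible

# Draw 1000 indices with replacement, with probability proportional
# to each sample's weight.
indices = rng.choices(range(len(weights)), weights=weights, k=1000)

counts = [indices.count(i) for i in range(len(weights))]
print(counts)
```

A sample with weight 0 is never drawn, and a sample with weight 3 is drawn about three times as often as one with weight 1. This is how adjusting weights can change what the model sees without modifying the underlying data.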

[16]:
sampler = tlc_train_dataset.create_sampler()

train_loader = torch.utils.data.DataLoader(
    tlc_train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
)

metrics_collection_dataloader_args = {
    "num_workers": NUM_WORKERS,
    "batch_size": COLLECT_METRICS_BATCH_SIZE,
}

predictor = tlc.Predictor(model, layers=[4])

# Train the model
for epoch in range(EPOCHS):
    train(model, device, train_loader, optimizer, epoch)

    tlc.collect_metrics(
        tlc_train_dataset,
        metrics_collectors=[
            mnist_metrics_collector,
            embeddings_metrics_collector,
        ],
        predictor=predictor,
        split="train",
        constants={"epoch": epoch},
        dataloader_args=metrics_collection_dataloader_args,
    )
    tlc.collect_metrics(
        tlc_val_dataset,
        metrics_collectors=[
            mnist_metrics_collector,
            embeddings_metrics_collector,
        ],
        predictor=predictor,
        split="val",
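        constants={"epoch": epoch},
        dataloader_args=metrics_collection_dataloader_args,
    )

    # Step the scheduler once per epoch so that LR_GAMMA actually decays the learning rate
    scheduler.step()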
[17]:
# Reduce embeddings using the final validation-set embeddings to fit a PaCMAP model
url_mapping = run.reduce_embeddings_by_foreign_table_url(
    tlc_val_dataset.url,
    method="pacmap",
    n_components=3,
)