PyTorch 3LC MNIST Sample Notebook
This notebook demonstrates training a Convolutional Neural Network (CNN) on the MNIST dataset using PyTorch and 3LC. The built-in torchvision MNIST training and validation datasets are each wrapped in a 3LC Table. Training runs for 5 epochs, and classification metrics and embeddings are collected on both splits after every epoch.
The notebook demonstrates:
- How to use a 3LC Table to integrate with built-in PyTorch datasets.
- Metrics collection using a custom MetricsCollector subclass and an EmbeddingsMetricsCollector.
- Reducing the dimensionality of embeddings using PaCMAP as a post-processing step.
Project Setup
# Parameters
PROJECT_NAME = "MNIST Digit Classification"
RUN_NAME = "Train MNIST Classifier"
DESCRIPTION = "Train a simple CNN to classify MNIST digits"
TRAIN_DATASET_NAME = "mnist-train"
VAL_DATASET_NAME = "mnist-val"
TRANSIENT_DATA_PATH = "../transient_data"
COLLECT_METRICS_BATCH_SIZE = 2048
TRAIN_BATCH_SIZE = 64
INITIAL_LR = 1.0
LR_GAMMA = 0.7
EPOCHS = 5
NUM_WORKERS = 0
DEVICE = "cuda:0"
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
INSTALL_DEPENDENCIES = False
%%capture
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install tlc[pacmap]
Imports
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm.auto import tqdm
import tlc
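Initialize a 3LC Run
# Create the Run; `run` is used at the end of the notebook to reduce the collected embeddings
run = tlc.init(
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
    description=DESCRIPTION,
)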
Setup Datasets
transform = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))]
)
train_dataset = torchvision.datasets.MNIST(root=TRANSIENT_DATA_PATH, train=True, download=True)
eval_dataset = torchvision.datasets.MNIST(root=TRANSIENT_DATA_PATH, train=False)
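ToTensor scales 8-bit pixel values into [0, 1], and Normalize then computes (x - mean) / std using the MNIST mean 0.1307 and standard deviation 0.3081 from the cell above. As a small sketch, the hypothetical helper below (plain Python, not part of torchvision) reproduces that arithmetic for a single pixel, showing that normalized inputs span roughly -0.42 (black background) to 2.82 (white stroke).

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(value: int, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Mimic ToTensor() followed by Normalize((mean,), (std,)) for one 8-bit pixel."""
    if not 0 <= value <= 255:
        raise ValueError("expected an 8-bit pixel value")
    scaled = value / 255.0  # ToTensor maps uint8 [0, 255] to float [0.0, 1.0]
    return (scaled - mean) / std  # Normalize subtracts the mean, then divides by the std


print(round(normalize_pixel(0), 4))    # black background pixel
print(round(normalize_pixel(255), 4))  # white stroke pixel
```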
class_names = [str(i) for i in range(10)]
structure = (tlc.PILImage("image"), tlc.CategoricalLabel("label", class_names))


def transforms(x):
    # Convert a (PIL image, int label) sample into (normalized image tensor, label tensor)
    return transform(x[0]), torch.tensor(x[1])


# We pick up the latest version of the dataset, so that we can re-run this notebook as-is
# after adding new revisions to the dataset.
tlc_train_dataset = (
    tlc.Table.from_torch_dataset(
        dataset=train_dataset,
        dataset_name=TRAIN_DATASET_NAME,
        structure=structure,
        project_name=PROJECT_NAME,
        description="MNIST training dataset",
        table_name="train",
        if_exists="overwrite",
    )
    .map(transforms)
    .latest()
)

tlc_val_dataset = (
    tlc.Table.from_torch_dataset(
        dataset=eval_dataset,
        dataset_name=VAL_DATASET_NAME,
        structure=structure,
        project_name=PROJECT_NAME,
        description="MNIST validation dataset",
        table_name="val",
        if_exists="overwrite",
    )
    .map(transforms)
    .latest()
)
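Calling .map(transforms) attaches a function that is applied as rows are read, turning each stored (image, label) sample into the (tensor, tensor) pair the model expects. The sketch below illustrates that lazy-mapping idea with a hypothetical MappedDataset class in plain Python; it is an assumption about the general pattern, not 3LC's implementation, and the toy rows stand in for images.

```python
from typing import Callable, Generic, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class MappedDataset(Generic[T, U]):
    """Hypothetical wrapper: applies `fn` to a row only when that row is indexed."""

    def __init__(self, rows: Sequence[T], fn: Callable[[T], U]) -> None:
        self._rows = rows
        self._fn = fn

    def __len__(self) -> int:
        return len(self._rows)

    def __getitem__(self, index: int) -> U:
        # The stored row is left untouched; the transform runs on every access.
        return self._fn(self._rows[index])


raw_rows = [([0, 255], 7), ([128, 64], 3)]  # toy (pixels, label) pairs


def to_model_input(row):
    # Scale pixels to [0, 1] and pass the label through unchanged.
    pixels, label = row
    return [p / 255.0 for p in pixels], label


mapped = MappedDataset(raw_rows, to_model_input)
print(len(mapped), mapped[0])
```

Keeping the transform lazy means the underlying table stays unchanged while every consumer sees model-ready samples.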
Setup Model
device = torch.device(DEVICE)
print(f"Using device: {device}")
Using device: cuda
class Net(nn.Module):
    # From https://github.com/pytorch/examples/blob/main/mnist/main.py
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels x 12 x 12 after conv2 + max pooling
        self.fc2 = nn.Linear(128, 10)  # one output per digit class

    def forward(self, x):
        x = self.conv1(x)  # (N, 1, 28, 28) -> (N, 32, 26, 26)
        x = F.relu(x)
        x = self.conv2(x)  # -> (N, 64, 24, 24)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)  # -> (N, 64, 12, 12)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # -> (N, 9216)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)  # log-probabilities, paired with nll_loss in training
        return output


model = Net().to(device)
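The 9216 input features of fc1 follow from the layer shapes: each 3x3 convolution with stride 1 and no padding shrinks each spatial dimension by 2, and the 2x2 max pool halves it, leaving 64 channels of 12x12. The hypothetical helper below applies the standard output-size formula floor((size + 2 * padding - kernel) / stride) + 1 in plain Python to confirm that number.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                          # MNIST images are 28x28
size = conv_output_size(size, kernel=3)            # conv1 -> 26
size = conv_output_size(size, kernel=3)            # conv2 -> 24
size = conv_output_size(size, kernel=2, stride=2)  # max_pool2d(x, 2) -> 12
flattened = 64 * size * size                       # conv2 has 64 output channels
print(size, flattened)
```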
Setup Training Loop
optimizer = torch.optim.Adadelta(model.parameters(), lr=INITIAL_LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=LR_GAMMA)
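With step_size=1, StepLR multiplies the learning rate by gamma each time scheduler.step() is called, so if it is stepped once per epoch, the rate used in epoch e is INITIAL_LR * LR_GAMMA ** e. The plain-Python sketch below tabulates that closed form for the parameter values above; step_lr is a hypothetical helper that mirrors, but does not call, torch.optim.lr_scheduler.StepLR.

```python
def step_lr(initial_lr: float, gamma: float, epoch: int, step_size: int = 1) -> float:
    """Learning rate after `epoch` epochs under StepLR-style decay."""
    return initial_lr * gamma ** (epoch // step_size)


# Same values as INITIAL_LR = 1.0, LR_GAMMA = 0.7 and EPOCHS = 5 in the parameters cell.
schedule = [step_lr(1.0, 0.7, e) for e in range(5)]
print([round(lr, 4) for lr in schedule])
```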
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for data, target in tqdm(train_loader, desc=f"Training {epoch+1}/{EPOCHS}"):  # Epoch is 0-indexed
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
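The model returns log-probabilities (log_softmax), so F.nll_loss, which takes the negative log-probability of the target class, amounts to cross-entropy on the raw logits. The plain-Python sketch below checks that identity on one made-up logit vector; log_softmax, nll and cross_entropy here are hypothetical helpers, not the PyTorch functions.

```python
import math


def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll(log_probs, target):
    # Negative log-likelihood of the target class for a single sample.
    return -log_probs[target]


def cross_entropy(logits, target):
    # Cross-entropy computed directly from softmax probabilities.
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[target] / sum(exps))


logits = [2.0, 0.5, -1.0]
target = 0
print(nll(log_softmax(logits), target), cross_entropy(logits, target))
```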
Setup Metrics Collectors
class MNISTMetricsCollector(tlc.MetricsCollector):
    def __init__(self, criterion):
        super().__init__()
        self._criterion = criterion

    def compute_metrics(self, batch, predictor_output):
        predictions = predictor_output.forward
        labels = batch[1].to(device)
        metrics = {
            "loss": self._criterion(predictions, labels).cpu().numpy(),
            "predicted": torch.argmax(predictions, dim=1).cpu().numpy(),
            "confidence": torch.exp(torch.max(predictions, dim=1).values).cpu().numpy(),
            "accuracy": (torch.argmax(predictions, dim=1) == labels).cpu().numpy(),
        }
        return metrics

    @property
    def column_schemas(self):
        # Explicitly override the schema of the predicted label, in order for it to be displayed as a
        # categorical label in the Dashboard.
        schemas = {
            "predicted": tlc.CategoricalLabelSchema(
                class_names,
                display_name="predicted label",
            )
        }
        return schemas


mnist_metrics_collector = MNISTMetricsCollector(nn.NLLLoss(reduction="none"))
embeddings_metrics_collector = tlc.EmbeddingsMetricsCollector(layers=[4])
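For each sample, compute_metrics takes the predicted class as the argmax of the log-probabilities, the confidence as exp of the largest log-probability (the softmax probability of the predicted class), accuracy as whether the prediction matches the label, and loss as the unreduced NLL. The sketch below reproduces those four columns in plain Python for a toy batch of two samples; per_sample_metrics is a hypothetical helper and the log-probabilities are made up.

```python
import math


def per_sample_metrics(log_probs_batch, labels):
    """Mirror of the loss/predicted/confidence/accuracy columns, one entry per sample."""
    metrics = {"loss": [], "predicted": [], "confidence": [], "accuracy": []}
    for log_probs, label in zip(log_probs_batch, labels):
        predicted = max(range(len(log_probs)), key=lambda i: log_probs[i])  # argmax over classes
        metrics["loss"].append(-log_probs[label])  # like NLLLoss(reduction="none")
        metrics["predicted"].append(predicted)
        metrics["confidence"].append(math.exp(log_probs[predicted]))  # top-class probability
        metrics["accuracy"].append(predicted == label)  # True if the prediction is correct
    return metrics


# Two samples over three classes, given as log-probabilities.
batch = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.1), math.log(0.3), math.log(0.6)],
]
labels = [0, 1]
print(per_sample_metrics(batch, labels))
```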
Run Training
We run training using a weighted sampler provided by the 3LC Table. The sampler uses the table's default weights column to sample the data. The weights can be updated in the Dashboard and will be picked up by the sampler automatically.
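A weighted sampler draws row indices with probability proportional to each row's weight, so a row with weight 0 is never drawn and a row with weight 2 is drawn about twice as often as one with weight 1. The sketch below illustrates weighted sampling with replacement using random.choices from the Python standard library; it is an assumption about the general behaviour, not 3LC's create_sampler implementation, and the weights are made up.

```python
import random
from collections import Counter


def weighted_indices(weights, num_samples, seed=0):
    """Draw row indices with replacement, proportional to `weights`."""
    rng = random.Random(seed)  # seeded for reproducibility
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


# Hypothetical per-row weights: row 2 is excluded, row 3 is up-weighted.
weights = [1.0, 1.0, 0.0, 2.0]
indices = weighted_indices(weights, num_samples=10_000)
counts = Counter(indices)
print(sorted(counts.items()))
```

Editing a weight changes how often that row appears in each epoch without removing it from the table.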
sampler = tlc_train_dataset.create_sampler()
train_loader = torch.utils.data.DataLoader(
    tlc_train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
)

metrics_collection_dataloader_args = {
    "num_workers": NUM_WORKERS,
    "batch_size": COLLECT_METRICS_BATCH_SIZE,
}

predictor = tlc.Predictor(model, layers=[4])

# Train the model, collecting metrics on both splits after every epoch
for epoch in range(EPOCHS):
    train(model, device, train_loader, optimizer, epoch)
    tlc.collect_metrics(
        tlc_train_dataset,
        metrics_collectors=[
            mnist_metrics_collector,
            embeddings_metrics_collector,
        ],
        predictor=predictor,
        split="train",
        constants={"epoch": epoch},
        dataloader_args=metrics_collection_dataloader_args,
    )
    tlc.collect_metrics(
        tlc_val_dataset,
        metrics_collectors=[
            mnist_metrics_collector,
            embeddings_metrics_collector,
        ],
        predictor=predictor,
        split="val",
        constants={"epoch": epoch},
        dataloader_args=metrics_collection_dataloader_args,
    )
    scheduler.step()  # Decay the learning rate by LR_GAMMA once per epoch
# Reduce embeddings using the final validation-set embeddings to fit a PaCMAP model
url_mapping = run.reduce_embeddings_by_foreign_table_url(
    tlc_val_dataset.url,
    method="pacmap",
    n_components=3,
)
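As the comment above describes, the reducer is fitted on one reference set of embeddings (the final validation-set embeddings). Assuming the fitted model is then reused to project the run's other embeddings, this is a fit-once, apply-elsewhere pattern that puts every projected point in one shared coordinate system. The sketch below illustrates that pattern with principal component analysis (PCA) via power iteration in plain Python, a deliberately simpler, linear stand-in for PaCMAP; fit_pca and project_points are hypothetical helpers, and the toy 3-D points stand in for real embeddings.

```python
import math
import random


def fit_pca(points, n_components, iterations=200, seed=0):
    """Fit a PCA projection (mean + top principal axes) using power iteration."""
    dim = len(points[0])
    n = len(points)
    mean = [sum(p[d] for p in points) / n for d in range(dim)]
    centered = [[p[d] - mean[d] for d in range(dim)] for p in points]
    # Covariance matrix of the reference points.
    cov = [[sum(c[i] * c[j] for c in centered) / n for j in range(dim)] for i in range(dim)]
    rng = random.Random(seed)
    axes = []
    for _ in range(n_components):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        for _ in range(iterations):
            w = [sum(cov[i][j] * v[j] for j in range(dim)) for i in range(dim)]
            norm = math.sqrt(sum(x * x for x in w))
            v = [x / norm for x in w]
        eigenvalue = sum(v[i] * sum(cov[i][j] * v[j] for j in range(dim)) for i in range(dim))
        # Deflate the covariance so the next pass finds the next principal axis.
        cov = [[cov[i][j] - eigenvalue * v[i] * v[j] for j in range(dim)] for i in range(dim)]
        axes.append(v)
    return mean, axes


def project_points(points, mean, axes):
    """Project points onto the fitted axes, reusing the reference set's mean."""
    return [
        [sum((p[d] - mean[d]) * axis[d] for d in range(len(mean))) for axis in axes]
        for p in points
    ]


# Reference set (plays the role of the validation embeddings), spread along (1, 1, 0).
reference = [[t, t, 0.1 * ((-1) ** i)] for i, t in enumerate(range(-5, 6))]
other = [[2.0, 2.0, 0.0], [-3.0, -3.0, 0.0]]  # e.g. embeddings from another split

mean, axes = fit_pca(reference, n_components=1)
print(project_points(other, mean, axes))
```

The sign of a principal axis is arbitrary, so only the magnitudes and relative positions of the projections are meaningful here.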