PyTorch 3LC Fashion-MNIST Sample Notebook

This notebook closely follows the MNIST notebook, but uses the Fashion-MNIST dataset instead.

As the original authors of the dataset noted: “MNIST is too easy… MNIST is overused… MNIST cannot represent modern CV tasks.”

Although the same criticism now applies to Fashion-MNIST, it remains a more interesting example because its images and labels are somewhat more complex.

This notebook trains a Convolutional Neural Network (CNN) on the Fashion-MNIST dataset using PyTorch and 3LC. Training runs for 5 epochs; after each epoch, per-sample classification metrics and embeddings are collected on both the training and validation sets.

The notebook demonstrates:

  • How to create a 3LC Table from a built-in PyTorch dataset.

  • Metrics collection using a custom MetricsCollector subclass and an EmbeddingsMetricsCollector.

  • Reducing the dimensionality of embeddings using PaCMAP as a post-processing step.

Project Setup

[2]:
# Parameters
PROJECT_NAME = "Fashion-MNIST Classification"
RUN_NAME = "Train a Fashion-MNIST Classifier"
DESCRIPTION = "Train a simple CNN to classify Fashion-MNIST images"
TRAIN_DATASET_NAME = "fashion-mnist-train"
VAL_DATASET_NAME = "fashion-mnist-val"
TRANSIENT_DATA_PATH = "../transient_data"
COLLECT_METRICS_BATCH_SIZE = 2048
TRAIN_BATCH_SIZE = 64
INITIAL_LR = 1.0
LR_GAMMA = 0.7
EPOCHS = 5
NUM_WORKERS = 0
DEVICE = None
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
INSTALL_DEPENDENCIES = False
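
As a rough guide to what the batch sizes above mean in practice, the sketch below counts batches per pass. It assumes the standard Fashion-MNIST split of 60,000 training and 10,000 test images, that one training pass draws as many samples as the table has rows, and that the final partial batch is kept (the PyTorch `DataLoader` default). The helper `num_batches` is illustrative and not part of 3LC or PyTorch.

```python
import math

TRAIN_SIZE, VAL_SIZE = 60_000, 10_000  # assumed: standard Fashion-MNIST split sizes
TRAIN_BATCH_SIZE = 64
COLLECT_METRICS_BATCH_SIZE = 2048


def num_batches(num_samples: int, batch_size: int) -> int:
    """Number of batches when the last, partial batch is kept."""
    return math.ceil(num_samples / batch_size)


print(num_batches(TRAIN_SIZE, TRAIN_BATCH_SIZE))            # 938 training steps per epoch
print(num_batches(TRAIN_SIZE, COLLECT_METRICS_BATCH_SIZE))  # 30 metrics batches on train
print(num_batches(VAL_SIZE, COLLECT_METRICS_BATCH_SIZE))    # 5 metrics batches on val
```

The much larger metrics-collection batch size is affordable because that pass only runs inference, with no gradients to store.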
[4]:
%%capture
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install 3lc[pacmap]

Imports

[7]:
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm.auto import tqdm

import tlc
[8]:
if DEVICE is None:
    if torch.cuda.is_available():
        device = "cuda:0"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
else:
    device = DEVICE

device = torch.device(device)
print(f"Using device: {device}")
Using device: cuda

Initialize a 3LC Run

[9]:
config = {
    "train_batch_size": TRAIN_BATCH_SIZE,
    "initial_lr": INITIAL_LR,
    "lr_gamma": LR_GAMMA,
    "epochs": EPOCHS,
}

run = tlc.init(
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
    description=DESCRIPTION,
    parameters=config,
    if_exists="overwrite",
)

Setup Datasets

[10]:
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))]
)

train_dataset = torchvision.datasets.FashionMNIST(root=TRANSIENT_DATA_PATH, train=True, download=True)
eval_dataset = torchvision.datasets.FashionMNIST(root=TRANSIENT_DATA_PATH, train=False)
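
The `Normalize((0.1307,), (0.3081,))` step in the cell above rescales each pixel as `(x - mean) / std`, after `ToTensor()` has mapped pixel values to the range [0, 1]. The plain-Python sketch below reproduces that arithmetic for the darkest and brightest possible pixels. The function name `normalize` is illustrative. Note that 0.1307 and 0.3081 are the statistics commonly quoted for the original MNIST digits, reused here unchanged.

```python
MEAN, STD = 0.1307, 0.3081  # the values passed to Normalize in the cell above


def normalize(pixel: float) -> float:
    """Apply the per-pixel formula used by Normalize: (x - mean) / std."""
    return (pixel - MEAN) / STD


print(round(normalize(0.0), 4))  # -0.4242 (black pixel)
print(round(normalize(1.0), 4))  # 2.8215 (white pixel)
```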

[11]:
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

structure = (tlc.PILImage("image"), tlc.CategoricalLabel("label", class_names))


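def transforms(x):
    # x is an (image, label) sample: normalize the PIL image into a tensor and wrap the label in a tensor
    return transform(x[0]), torch.tensor(x[1])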


# We pick up the latest version of the dataset, so that we can re-run this notebook as-is
# after adding new revisions to the dataset.
tlc_train_dataset = (
    tlc.Table.from_torch_dataset(
        dataset=train_dataset,
        structure=structure,
        dataset_name=TRAIN_DATASET_NAME,
        project_name=PROJECT_NAME,
        description="Fashion-MNIST training dataset",
        table_name="train",
        if_exists="overwrite",
    )
    .map(transforms)
    .latest()
)

tlc_val_dataset = (
    tlc.Table.from_torch_dataset(
        dataset=eval_dataset,
        dataset_name=VAL_DATASET_NAME,
        structure=structure,
        project_name=PROJECT_NAME,
        description="Fashion-MNIST validation dataset",
        table_name="val",
        if_exists="overwrite",
    )
    .map(transforms)
    .latest()
)

Setup Model

[12]:
class Net(nn.Module):
    # From https://github.com/pytorch/examples/blob/main/mnist/main.py
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


model = Net().to(device)
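
The input size of 9216 for `fc1` follows from the layer shapes: each unpadded 3×3 convolution trims a 28×28 image by 2 pixels per side length, and the 2×2 max pool halves it, leaving 64 channels of 12×12. A small sketch of that arithmetic, using the standard output-size formula for convolutions (the helper names `conv_out` and `pool_out` are illustrative):

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int) -> int:
    """Spatial output size of a non-overlapping, unpadded max pool."""
    return size // kernel


size = 28                 # Fashion-MNIST images are 28x28
size = conv_out(size, 3)  # conv1: 28 -> 26
size = conv_out(size, 3)  # conv2: 26 -> 24
size = pool_out(size, 2)  # max_pool2d: 24 -> 12

flattened = 64 * size * size  # 64 output channels from conv2
print(flattened)  # 9216
```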

Setup Training Loop

[13]:
optimizer = torch.optim.Adadelta(model.parameters(), lr=INITIAL_LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=LR_GAMMA)
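
With `step_size=1`, `StepLR` multiplies the learning rate by `gamma` each time `scheduler.step()` is called. The sketch below tabulates the resulting schedule for the parameters in this notebook, assuming `scheduler.step()` is called once at the end of every epoch. The function `step_lr` is an illustrative stand-in for the closed-form rule, not a PyTorch API.

```python
INITIAL_LR, LR_GAMMA, EPOCHS = 1.0, 0.7, 5


def step_lr(initial_lr: float, gamma: float, epoch: int, step_size: int = 1) -> float:
    """Learning rate in effect during a given 0-indexed epoch."""
    return initial_lr * gamma ** (epoch // step_size)


schedule = [round(step_lr(INITIAL_LR, LR_GAMMA, epoch), 4) for epoch in range(EPOCHS)]
print(schedule)  # [1.0, 0.7, 0.49, 0.343, 0.2401]
```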
[14]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for data, target in tqdm(train_loader, desc=f"Training {epoch+1}/{EPOCHS}"):  # Epoch is 0-indexed
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        loss.backward()
        optimizer.step()

Setup Metrics Collectors

[15]:
class FashionMNISTMetricsCollector(tlc.MetricsCollector):
    def __init__(self, criterion):
        super().__init__()
        self.criterion = criterion

    def compute_metrics(self, batch, predictor_output):
        predictions = predictor_output.forward
        labels = batch[1].to(device)

        metrics = {
            "loss": self.criterion(predictions, labels).cpu().numpy(),
            "predicted": torch.argmax(predictions, dim=1).cpu().numpy(),
            "confidence": torch.exp(torch.max(predictions, dim=1).values).cpu().numpy(),
            "accuracy": (torch.argmax(predictions, dim=1) == labels).cpu().numpy(),
        }
        return metrics

    @property
    def column_schemas(self):
        # Explicitly override the schema of the predicted label, in order for it to be displayed as a
        # categorical label in the Dashboard.
        schemas = {
            "predicted": tlc.CategoricalLabelSchema(
                class_names,
                display_name="predicted label",
            )
        }
        return schemas


mnist_metrics_collector = FashionMNISTMetricsCollector(nn.NLLLoss(reduction="none"))
embeddings_metrics_collector = tlc.EmbeddingsMetricsCollector(layers=[4])
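
Because the model ends in `log_softmax`, its outputs are log-probabilities: the per-sample loss is the negated log-probability of the true class, and the confidence is the exponential of the largest log-probability. The plain-Python sketch below mirrors the four metrics for a single sample so the relationships are easy to check by hand. The helper names and the example logits are made up for illustration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def sample_metrics(logits, label):
    """Per-sample equivalents of the metrics computed in compute_metrics."""
    log_probs = log_softmax(logits)
    predicted = max(range(len(log_probs)), key=log_probs.__getitem__)
    return {
        "loss": -log_probs[label],                      # NLL of the true class
        "predicted": predicted,                         # argmax over classes
        "confidence": math.exp(log_probs[predicted]),   # probability of the predicted class
        "accuracy": predicted == label,                 # per-sample correctness
    }


metrics = sample_metrics([2.0, 0.5, -1.0], label=0)
print(metrics["predicted"], metrics["accuracy"])  # 0 True
# For a correctly classified sample, confidence equals exp(-loss).
print(math.isclose(metrics["confidence"], math.exp(-metrics["loss"])))  # True
```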

Run Training

We train with a weighted sampler provided by the 3LC Table. The sampler draws samples according to the table's default weights column. Weights edited in the Dashboard are picked up automatically by the sampler.
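
To illustrate the general idea of weighted sampling, here is a minimal standard-library sketch. It is not 3LC's implementation: it assumes sampling with replacement and uses a made-up weight list, and the function `weighted_indices` is hypothetical.

```python
import random
from collections import Counter


def weighted_indices(weights, num_samples, seed=0):
    """Draw sample indices with probability proportional to their weight."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


# Hypothetical per-sample weights: sample 2 is excluded, sample 3 is emphasized.
weights = [1.0, 1.0, 0.0, 2.0]
counts = Counter(weighted_indices(weights, num_samples=10_000))

print(counts[2])              # 0: a zero-weight sample is never drawn
print(counts[3] > counts[0])  # True: a higher weight means more frequent draws
```

In this picture, lowering a sample's weight reduces how often it is seen during training, and setting it to zero removes it from training without deleting it from the table.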

[16]:
sampler = tlc_train_dataset.create_sampler()

train_loader = torch.utils.data.DataLoader(
    tlc_train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
)

metrics_collection_dataloader_args = {
    "num_workers": NUM_WORKERS,
    "batch_size": COLLECT_METRICS_BATCH_SIZE,
}

predictor = tlc.Predictor(model, layers=[4])

# Train the model
for epoch in range(EPOCHS):
    train(model, device, train_loader, optimizer, epoch)

    tlc.collect_metrics(
        tlc_train_dataset,
        metrics_collectors=[
            mnist_metrics_collector,
            embeddings_metrics_collector,
        ],
        predictor=predictor,
        split="train",
        constants={"epoch": epoch},
        dataloader_args=metrics_collection_dataloader_args,
    )
    tlc.collect_metrics(
        tlc_val_dataset,
        metrics_collectors=[
            mnist_metrics_collector,
            embeddings_metrics_collector,
        ],
        predictor=predictor,
        split="val",
        constants={"epoch": epoch},
        dataloader_args=metrics_collection_dataloader_args,
    )

    # Decay the learning rate by LR_GAMMA once per epoch; without this call the scheduler has no effect
    scheduler.step()
[17]:
# Reduce embeddings using the final validation-set embeddings to fit a PaCMAP model
url_mapping = run.reduce_embeddings_by_foreign_table_url(
    tlc_val_dataset.url,
    method="pacmap",
    n_components=3,
)