PyTorch 3LC CIFAR-10 Sample Notebook

This notebook demonstrates fine-tuning a pretrained ResNet-18 model on the CIFAR-10 dataset using PyTorch and 3LC. We run the fine-tuning process for 5 epochs. During training, both classification and embeddings metrics are collected.

The notebook covers:

  • Creating a Table from a PyTorch Dataset.

  • Fine-tuning a pretrained ResNet-18 on CIFAR-10 using the Table.

  • Using FunctionalMetricsCollector and EmbeddingsMetricsCollector for metrics and embedding collection.

  • Reducing the dimensionality of embeddings using UMAP after training completes.

[2]:
# Parameters
PROJECT_NAME = "CIFAR-10 Image Classification"
RUN_NAME = "CIFAR-10 Demo Run"
DESCRIPTION = "Fine-tune ResNet18 on CIFAR-10"
TRAIN_DATASET_NAME = "cifar-10-train"
VAL_DATASET_NAME = "cifar-10-val"
NUM_CLASSES = 10
EMBEDDINGS_COLLECTION_FREQUENCY = 4
TRANSIENT_DATA_PATH = "../transient_data"
EPOCHS = 5
BATCH_SIZE = 32
INITIAL_LR = 0.01
LR_GAMMA = 0.9
NUM_WORKERS = 0
MODEL_NAME = "resnet18"
PRETRAINED = True
DEVICE = "cuda:0"
DROP_RATE = 0.2
DROP_PATH_RATE = 0.2
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
INSTALL_DEPENDENCIES = False
[4]:
%%capture
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install timm
    %pip --quiet install tlc[umap]
[7]:
from __future__ import annotations

import torch
import torchvision
from tqdm.auto import tqdm

import tlc

Initialize a 3LC Run

First, we initialize a 3LC run. This creates a new, empty run that is visible in the 3LC dashboard.

[9]:
run = tlc.init(
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
    description=DESCRIPTION,
    if_exists="overwrite",
)
[10]:
config = {
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "initial_lr": INITIAL_LR,
    "lr_gamma": LR_GAMMA,
    "model_name": MODEL_NAME,
    "pretrained": PRETRAINED,
    "drop_rate": DROP_RATE,
    "drop_path_rate": DROP_PATH_RATE,
}

# Persist the notebook configuration parameters to the run
run.set_parameters(config)

Set Up Datasets

We create a Table from the torchvision CIFAR-10 dataset. The Table is used to visualize the data in the 3LC dashboard and to associate the collected metrics with individual samples. It also lets the user make virtual edits to the dataset and run new experiments on the modified dataset.

Since the underlying CIFAR-10 dataset is not stored as individual image files, 3LC will copy the images to the configured sample root.

[11]:
train_dataset = torchvision.datasets.CIFAR10(root=TRANSIENT_DATA_PATH, train=True, download=True)
val_dataset = torchvision.datasets.CIFAR10(root=TRANSIENT_DATA_PATH, train=False)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /home/build/ado/w/2/s/tlc-monorepo/public-notebooks/transient_data/cifar-10-python.tar.gz
100%|██████████| 170498071/170498071 [00:01<00:00, 97695838.71it/s]
Extracting /home/build/ado/w/2/s/tlc-monorepo/public-notebooks/transient_data/cifar-10-python.tar.gz to /home/build/ado/w/2/s/tlc-monorepo/public-notebooks/transient_data
[12]:
# The `structure` describes the layout of the samples in the dataset.
# This helps 3LC create a Table with the correct columns and schemas.
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

structure = (
    tlc.PILImage("image"),
    tlc.CategoricalLabel("label", classes=class_names),
)

train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(224),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

val_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def train_fn(sample):
    return train_transform(sample[0]), sample[1]


def val_fn(sample):
    return val_transform(sample[0]), sample[1]


# Create the 3LC Tables

# Notice that instead of assigning the transforms to the torch dataset
# we created, we assign them to the 3LC Table using `map`. This is because
# we want the Table to be able to capture the untransformed images.

# Notice that we also call `map_collect_metrics` on the training table
# to specify the transforms which should be used to collect metrics.
# Since we don't want 3LC to collect metrics on augmented images, we
# use the validation transforms for metrics collection. If
# `map_collect_metrics` is not called, the transforms given to `map`
# will be used for metrics collection.

tlc_train_dataset = (
    tlc.Table.from_torch_dataset(
        dataset=train_dataset,
        dataset_name=TRAIN_DATASET_NAME,
        table_name="train",
        structure=structure,
        if_exists="overwrite",
    )
    .map(train_fn)
    .map_collect_metrics(val_fn)
)

tlc_val_dataset = tlc.Table.from_torch_dataset(
    dataset=val_dataset,
    dataset_name=VAL_DATASET_NAME,
    table_name="val",
    structure=structure,
    if_exists="overwrite",
).map(val_fn)


# Automatically pick up the latest version of the tables to include edits committed in the dashboard.
initial_train_url = tlc_train_dataset.url
initial_val_url = tlc_val_dataset.url

tlc_train_dataset = tlc_train_dataset.latest()
tlc_val_dataset = tlc_val_dataset.latest()

if tlc_train_dataset.url != initial_train_url:
    print(f"Using latest training table {tlc_train_dataset.url}")
else:
    print(f"Using source training table {initial_train_url}")

if tlc_val_dataset.url != initial_val_url:
    print(f"Using latest validation table {tlc_val_dataset.url}")
else:
    print(f"Using source validation table {initial_val_url}")
Using source training table /home/build/.local/share/3LC/projects/CIFAR-10 Image Classification/datasets/cifar-10-train/tables/train
Using source validation table /home/build/.local/share/3LC/projects/CIFAR-10 Image Classification/datasets/cifar-10-val/tables/val
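To make the roles of `map` and `map_collect_metrics` more concrete, here is a minimal, hypothetical sketch in plain Python. The class `TransformingTable` and its methods `raw` and `for_metrics` are illustrative stand-ins, not 3LC's implementation. It keeps the raw samples untouched, applies the training transform when indexed, and applies a separate metrics transform when samples are read for metrics collection, falling back to the training transform when no metrics transform is set.

```python
from __future__ import annotations

from typing import Any, Callable, Sequence


class TransformingTable:
    """Hypothetical stand-in for a Table supporting `map` and `map_collect_metrics`.

    The raw samples are never modified; transforms are applied lazily on read.
    """

    def __init__(self, samples: Sequence[Any]) -> None:
        self._samples = samples
        self._train_fn: Callable[[Any], Any] | None = None
        self._metrics_fn: Callable[[Any], Any] | None = None

    def map(self, fn: Callable[[Any], Any]) -> TransformingTable:
        self._train_fn = fn
        return self  # returning self allows chaining, as in the notebook

    def map_collect_metrics(self, fn: Callable[[Any], Any]) -> TransformingTable:
        self._metrics_fn = fn
        return self

    def raw(self, index: int) -> Any:
        # The untransformed sample, i.e. what a viewer would display
        return self._samples[index]

    def __getitem__(self, index: int) -> Any:
        # Read path used for training: applies the training transform
        sample = self._samples[index]
        return self._train_fn(sample) if self._train_fn else sample

    def for_metrics(self, index: int) -> Any:
        # Read path used for metrics collection; falls back to the training transform
        fn = self._metrics_fn or self._train_fn
        sample = self._samples[index]
        return fn(sample) if fn else sample

    def __len__(self) -> int:
        return len(self._samples)


# Toy "images" are plain floats: the stand-in training transform adds an offset,
# and the stand-in metrics transform rescales instead.
table = TransformingTable([(1.0, 0), (2.0, 1)])
table.map(lambda s: (s[0] + 100.0, s[1])).map_collect_metrics(lambda s: (s[0] * 10.0, s[1]))

print(table.raw(0))          # (1.0, 0)
print(table[0])              # (101.0, 0)
print(table.for_metrics(0))  # (10.0, 0)
```

Keeping the raw samples separate from the transforms is what allows the untransformed images to be shown while training still sees augmented ones.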

Set Up Model

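We create a ResNet-18 model with the timm library. Passing `num_classes=NUM_CLASSES` replaces the pretrained ImageNet classification head with a new 10-class head.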

[13]:
import timm

torch.backends.cudnn.benchmark = True

device = torch.device(DEVICE)
print(f"Training will use device {device}")

model = timm.create_model(
    MODEL_NAME,
    pretrained=PRETRAINED,
    num_classes=NUM_CLASSES,
    drop_rate=DROP_RATE,
    drop_path_rate=DROP_PATH_RATE
).to(device)
Training will use device cuda

Set Up Training Loop

[14]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=INITIAL_LR,
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True,
)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
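`ExponentialLR` multiplies the learning rate by `gamma` every time `lr_scheduler.step()` is called, and the training loop calls it once at the end of each epoch. Epoch `e` (counting from 0) therefore trains with `INITIAL_LR * LR_GAMMA**e`. This short pure-Python sketch evaluates that closed form with this notebook's parameter values; it does not use the scheduler itself.

```python
INITIAL_LR = 0.01
LR_GAMMA = 0.9
EPOCHS = 5

# Learning rate in effect during each epoch when the scheduler steps once per epoch
schedule = [INITIAL_LR * LR_GAMMA**epoch for epoch in range(EPOCHS)]

for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch}: lr={lr:.6f}")
```

The five values it produces, from 0.010000 down to 0.006561, agree with the `lr` values printed during training.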
[15]:
def train(model, loader, criterion, optimizer, scaler):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
            outputs = model(images)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item() * labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return train_loss / total, 100 * correct / total
[16]:
def validate(model, loader, criterion):
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return val_loss / total, 100 * correct / total
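Both `train` and `validate` accumulate `loss.item() * labels.size(0)` and divide by the total number of samples instead of averaging the per-batch losses. This matters because the last batch is usually smaller: the 10,000 validation images in batches of 32 leave a final batch of 16. The pure-Python sketch below uses made-up batch losses (illustrative numbers, not results from this run) to show the difference between the two averages.

```python
from __future__ import annotations


def sample_weighted_mean(batch_means: list[float], batch_sizes: list[int]) -> float:
    """Mean over all samples, as computed by `train` and `validate`."""
    total = sum(batch_sizes)
    return sum(m * n for m, n in zip(batch_means, batch_sizes)) / total


def mean_of_batch_means(batch_means: list[float]) -> float:
    """Naive alternative that gives every batch equal weight."""
    return sum(batch_means) / len(batch_means)


# Illustrative values: two full batches of 32 and a final partial batch of 16
batch_means = [0.30, 0.20, 0.90]
batch_sizes = [32, 32, 16]

print(sample_weighted_mean(batch_means, batch_sizes))  # about 0.38
print(mean_of_batch_means(batch_means))                # about 0.467, overweights the small batch
```

The same reasoning applies to the accuracy, which divides the number of correct predictions by `total`.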

Set Up Metrics Collectors

[17]:
# List the model's modules together with their indices. This is useful for finding the
# index of a named module, which is needed to create the embeddings metrics collector.
# Uncomment the print statement to see the full list.
indices_and_modules = list(enumerate(model.named_modules()))
for idx, (name, module) in indices_and_modules:
    # print(idx, name)
    pass

# The final fully connected layer will be used for collecting the embeddings.
final_flatten_layer_index = indices_and_modules[-1][0]
final_flatten_layer_name = indices_and_modules[-1][1][0]

print(f"Using layer {final_flatten_layer_index} ({final_flatten_layer_name}) for embeddings collection")
Using layer 100 (fc) for embeddings collection
[18]:
import torch.nn.functional as F


## Define a function that computes the per-sample metrics to collect; it is passed to a FunctionalMetricsCollector
def metrics_fn(samples: tuple[torch.Tensor, torch.Tensor], predictions: torch.Tensor, _):
    labels = samples[1]
    num_classes = predictions.shape[1]
    one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()

    # Confidence & Predicted
    softmax_output = F.softmax(predictions, dim=1)
    predicted_indices = torch.argmax(predictions, dim=1)
    confidence = torch.gather(softmax_output, 1, predicted_indices.unsqueeze(1)).squeeze(1)

    # Per-sample accuracy (1 if correct, 0 otherwise)
    accuracy = (predicted_indices == labels).float()

    # Unreduced Cross Entropy Loss
    cross_entropy_loss: torch.Tensor = torch.nn.CrossEntropyLoss(reduction="none")(predictions, labels)

    # RMSE
    mse: torch.Tensor = torch.nn.MSELoss(reduction="none")(softmax_output, one_hot_labels)
    mse = mse.mean(dim=1)
    rmse = torch.sqrt(mse)

    # MAE
    mae: torch.Tensor = torch.nn.L1Loss(reduction="none")(softmax_output, one_hot_labels)
    mae = mae.mean(dim=1)

    return {
        "loss": cross_entropy_loss.cpu().numpy(),
        "predicted": predicted_indices.cpu().numpy(),
        "accuracy": accuracy.cpu().numpy(),
        "confidence": confidence.cpu().numpy(),
        "rmse": rmse.cpu().numpy(),
        "mae": mae.cpu().numpy(),
    }


# Schemas are inferred automatically, but they can be defined explicitly when customization is needed,
# for example to set the description, display name, display importance, or class names.

schemas = {
    "loss": tlc.Schema(
        description="Cross entropy loss",
        value=tlc.Float32Value(),
    ),
    "predicted": tlc.CategoricalLabelSchema(
        display_name="predicted label",
        class_names=class_names,
    ),
}

## Define metrics collectors

classification_metrics_collector = tlc.FunctionalMetricsCollector(
    model=model,
    collection_fn=metrics_fn,
    column_schemas=schemas,
)

embeddings_metrics_collector = tlc.EmbeddingsMetricsCollector(
    model=model,
    layers=[final_flatten_layer_index],
)
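For a single sample, the quantities returned by `metrics_fn` reduce to short formulas over the softmax probabilities and the one-hot label vector. The following pure-Python sketch mirrors that computation for one sample so the definitions can be checked by hand; the function name `per_sample_metrics` is hypothetical, and it illustrates rather than replaces the batched PyTorch version. As in `metrics_fn`, RMSE and MAE compare the softmax output with the one-hot label and average over the classes.

```python
from __future__ import annotations

import math


def per_sample_metrics(logits: list[float], label: int) -> dict[str, float]:
    num_classes = len(logits)

    # Numerically stable softmax
    max_logit = max(logits)
    exps = [math.exp(z - max_logit) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    one_hot = [1.0 if k == label else 0.0 for k in range(num_classes)]
    predicted = max(range(num_classes), key=lambda k: logits[k])  # argmax, first index on ties

    return {
        "loss": -math.log(probs[label]),  # cross entropy on the true class
        "predicted": predicted,
        "accuracy": float(predicted == label),
        "confidence": probs[predicted],
        "rmse": math.sqrt(sum((p - t) ** 2 for p, t in zip(probs, one_hot)) / num_classes),
        "mae": sum(abs(p - t) for p, t in zip(probs, one_hot)) / num_classes,
    }


metrics = per_sample_metrics([2.0, 1.0, 0.1], label=0)
print(metrics)
```

Because the probabilities sum to one, the MAE defined this way always equals `2 * (1 - p_label) / num_classes`, so it carries the same information as the probability assigned to the true class.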

Run Training

We have now defined our training and validation datasets, defined our model, and configured our metrics collectors. We are ready to run training.

[19]:
from torch.utils.data import DataLoader

# Create a weighted sampler to determine the sampling probability for each sample
sampler = tlc_train_dataset.create_sampler()

# Create training and validation dataloaders using our 3LC Tables
train_loader = DataLoader(
    tlc_train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    sampler=sampler,
)

val_loader = DataLoader(
    tlc_val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# We can use a larger batch size for the metrics collection, as we don't need to backpropagate
metrics_collection_dataloader_args = {"num_workers": NUM_WORKERS, "batch_size": 512}

# We will collect the learning rate as a constant value per metrics-collection run,
# but we want it to be hidden by default in the Dashboard.
learning_rate_schema = tlc.Schema(
    display_name="LR",
    description="Learning rate",
    value=tlc.Float32Value(),
    default_visible=False,
)

# Train the model
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, scaler)
    val_loss, val_acc = validate(model, val_loader, criterion)

    # Collect classification metrics and embeddings for both splits every epoch
    tlc.collect_metrics(
        tlc_val_dataset,
        metrics_collectors=[classification_metrics_collector, embeddings_metrics_collector],
        split="val",
        constants={"epoch": epoch, "learning_rate": lr_scheduler.get_last_lr()[0]},
        constants_schemas={"learning_rate": learning_rate_schema},
        dataloader_args=metrics_collection_dataloader_args,
    )
    tlc.collect_metrics(
        tlc_train_dataset,
        metrics_collectors=[classification_metrics_collector, embeddings_metrics_collector],
        split="train",
        constants={"epoch": epoch, "learning_rate": lr_scheduler.get_last_lr()[0]},
        constants_schemas={"learning_rate": learning_rate_schema},
        dataloader_args=metrics_collection_dataloader_args,
    )

    print(
        f"Epoch {epoch + 1}/{EPOCHS}:, Train Loss: {train_loss:.4f}, "
        f"Train Acc: {train_acc:.2f}%, lr: {lr_scheduler.get_last_lr()[0]:.6f}"
    )
    print(f"Epoch {epoch + 1}/{EPOCHS}:, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    lr_scheduler.step()
Epoch 1/5:, Train Loss: 0.8221, Train Acc: 71.55%, lr: 0.010000
Epoch 1/5:, Val Loss: 0.3013, Val Acc: 89.76%
Epoch 2/5:, Train Loss: 0.4983, Train Acc: 83.06%, lr: 0.009000
Epoch 2/5:, Val Loss: 0.2615, Val Acc: 91.26%
Epoch 3/5:, Train Loss: 0.4269, Train Acc: 85.51%, lr: 0.008100
Epoch 3/5:, Val Loss: 0.2228, Val Acc: 92.31%
Epoch 4/5:, Train Loss: 0.3800, Train Acc: 87.16%, lr: 0.007290
Epoch 4/5:, Val Loss: 0.1957, Val Acc: 93.27%
Epoch 5/5:, Train Loss: 0.3508, Train Acc: 87.94%, lr: 0.006561
Epoch 5/5:, Val Loss: 0.1948, Val Acc: 93.25%
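The training `DataLoader` draws its indices from `tlc_train_dataset.create_sampler()`. To illustrate what weight-based sampling means, here is a hypothetical pure-Python sampler, `ProportionalSampler`, that draws indices with replacement and with probability proportional to a per-sample weight. It assumes non-negative weights and that a weight of 0 excludes a sample; it is a sketch of the general technique, similar in spirit to PyTorch's `WeightedRandomSampler`, and not 3LC's implementation.

```python
from __future__ import annotations

import random
from collections import Counter
from typing import Iterator, Sequence


class ProportionalSampler:
    """Hypothetical sampler: yields indices with probability proportional to weight."""

    def __init__(self, weights: Sequence[float], num_samples: int, seed: int | None = None) -> None:
        if any(w < 0 for w in weights):
            raise ValueError("weights must be non-negative")
        if sum(weights) == 0:
            raise ValueError("at least one weight must be positive")
        self.weights = list(weights)
        self.num_samples = num_samples
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        # Draw with replacement; an index with weight 0 can never be selected
        indices = range(len(self.weights))
        return iter(self.rng.choices(indices, weights=self.weights, k=self.num_samples))

    def __len__(self) -> int:
        return self.num_samples


# Hypothetical weights: sample 2 is twice as likely as sample 0, sample 1 is never drawn
sampler = ProportionalSampler([1.0, 0.0, 2.0], num_samples=3000, seed=0)
counts = Counter(sampler)
print(counts)
```

Since the class implements `__iter__` and `__len__`, an object like this could be passed as the `sampler` argument of a PyTorch `DataLoader`, which accepts any iterable of indices.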
[20]:
# Reduce the collected embeddings with UMAP, fitting the reducer on the embeddings collected for the training table
url_mapping = run.reduce_embeddings_by_foreign_table_url(tlc_train_dataset.url, n_neighbors=10)