PyTorch 3LC CIFAR-10 Sample Notebook

This notebook demonstrates fine-tuning a pretrained ResNet-18 model on the CIFAR-10 dataset using PyTorch and 3LC. We run the fine-tuning process for 5 epochs. During training, both classification and embeddings metrics are collected.

image1

The notebook covers:

  • Creating a Table from a PyTorch Dataset.

  • Fine-tuning a pretrained ResNet-18 on CIFAR-10 using the Table.

  • Using FunctionalMetricsCollector and EmbeddingsMetricsCollector for metrics and embedding collection.

  • Reducing the dimensionality of embeddings using PaCMAP after training completes.

Project Setup

[ ]:
# --- 3LC project / run naming ---
PROJECT_NAME = "3LC Tutorials - CIFAR-10"
RUN_NAME = "CIFAR-10 Demo Run"
DESCRIPTION = "Fine-tune ResNet18 on CIFAR-10"
TRAIN_DATASET_NAME = "cifar-10-train"
VAL_DATASET_NAME = "cifar-10-val"

# --- Data ---
DOWNLOAD_PATH = "../../transient_data"  # root directory handed to torchvision.datasets.CIFAR10
NUM_CLASSES = 10
NUM_WORKERS = 0  # DataLoader worker processes (0 = load in the main process)

# --- Model (timm) ---
TIMM_MODEL_NAME = "resnet18"
PRETRAINED = True
DROP_RATE = 0.2
DROP_PATH_RATE = 0.2

# --- Optimization ---
EPOCHS = 5
BATCH_SIZE = 32
INITIAL_LR = 0.01
LR_GAMMA = 0.9  # ExponentialLR decay factor; the scheduler is stepped once per epoch

# --- Metrics collection ---
EMBEDDINGS_COLLECTION_FREQUENCY = 4

# --- Runtime ---
DEVICE = None  # None = choose automatically (CUDA, then MPS, then CPU)
INSTALL_DEPENDENCIES = True
[ ]:
# Install the extra packages this notebook needs. `%pip` is IPython magic, so this cell
# only runs inside Jupyter/IPython; set INSTALL_DEPENDENCIES = False to skip it.
if INSTALL_DEPENDENCIES:
    # timm supplies the ResNet-18 definition used by `timm.create_model` below.
    %pip --quiet install timm
    # The `pacmap` extra is required for the PaCMAP embeddings reduction at the end.
    %pip --quiet install 3lc[pacmap]

Imports

[ ]:
import tlc
import torch
import torchvision
from tqdm.auto import tqdm
[ ]:
# Resolve the compute device: an explicit DEVICE override wins; otherwise prefer
# CUDA, then Apple's Metal (MPS) backend, and finally fall back to the CPU.
if DEVICE is not None:
    device = DEVICE
elif torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Normalize the string (or user-supplied value) into a torch.device object.
device = torch.device(device)
print(f"Using device: {device}")

Initialize a 3LC Run

First, we initialize a 3LC run. This will create a new empty run which will be visible in the 3LC dashboard.

[ ]:
# Create the 3LC run for this notebook, identified by project and run name.
run = tlc.init(
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
    if_exists="overwrite",
    description=DESCRIPTION,
)
[ ]:
# Hyper-parameters of this notebook, recorded on the run via `set_parameters`.
config = dict(
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    initial_lr=INITIAL_LR,
    lr_gamma=LR_GAMMA,
    model_name=TIMM_MODEL_NAME,
    pretrained=PRETRAINED,
    drop_rate=DROP_RATE,
    drop_path_rate=DROP_PATH_RATE,
)
run.set_parameters(config)

Setup Datasets

We create one Table for each split of the torchvision CIFAR-10 dataset. The Tables are used for visualization in the 3LC dashboard and for associating metrics with individual samples. They also allow the user to make virtual edits to the dataset and run new experiments on the modified dataset.

Since the underlying CIFAR-10 dataset is not stored as individual image files, 3LC will copy the images to the configured sample root.

[ ]:
# Both splits share the same torchvision root. Pass download=True to each, so the
# validation split does not silently rely on the training call having fetched the
# archive first. torchvision skips the download when the files are already present
# and verified.
train_dataset = torchvision.datasets.CIFAR10(root=DOWNLOAD_PATH, train=True, download=True)
val_dataset = torchvision.datasets.CIFAR10(root=DOWNLOAD_PATH, train=False, download=True)
[ ]:
# CIFAR-10 class names, in the order of torchvision's CIFAR-10 label indices.
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# Sample layout for 3LC: an image column and a categorical label column. This lets 3LC
# create the Table with the correct columns and schemas.
schema = {
    "image": tlc.schemas.ImageSchema(),
    "label": tlc.schemas.CategoricalLabelSchema(classes=class_names),
}

# Standard ImageNet channel statistics, shared by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_tv_transforms = torchvision.transforms

# Training pipeline: upsample, apply random flip/affine augmentation, then normalize.
train_transform = _tv_transforms.Compose(
    [
        _tv_transforms.Resize(224),
        _tv_transforms.RandomHorizontalFlip(),
        _tv_transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        _tv_transforms.ToTensor(),
        _tv_transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Evaluation pipeline: the same resize and normalization, but no random augmentation.
val_transform = _tv_transforms.Compose(
    [
        _tv_transforms.Resize(224),
        _tv_transforms.ToTensor(),
        _tv_transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)


def train_fn(sample):
    """Turn a raw Table row into an ``(augmented image tensor, label)`` training pair."""
    image = sample["image"]
    return train_transform(image), sample["label"]


def val_fn(sample):
    """Turn a raw Table row into an ``(image tensor, label)`` pair with no random augmentation."""
    image = sample["image"]
    return val_transform(image), sample["label"]


# Build one 3LC Table per split. The Tables store the untransformed samples, so they
# capture the original images; transforms are attached afterwards through non-mutating
# views (`Table.with_transform`).
#
# The training data also gets a second, non-augmenting view (built with `val_fn`) for
# metrics collection, because metrics should not be computed on augmented images.


def _create_split_table(dataset, dataset_name, table_name, description):
    """Create a 3LC Table for one CIFAR-10 split, using the shared ``schema``."""
    return tlc.Table.from_torch_dataset(
        dataset=dataset,
        dataset_name=dataset_name,
        table_name=table_name,
        description=description,
        schema=schema,
        if_exists="overwrite",
    )


tlc_train_table = _create_split_table(train_dataset, TRAIN_DATASET_NAME, "train", "CIFAR-10 training dataset")
tlc_val_table = _create_split_table(val_dataset, VAL_DATASET_NAME, "val", "CIFAR-10 validation dataset")

# Switch to the newest revision of each Table, so that edits committed in the Dashboard
# are used. Then report whether that revision differs from the Table just written.
initial_train_url = tlc_train_table.url
initial_val_url = tlc_val_table.url

tlc_train_table, tlc_val_table = tlc_train_table.latest(), tlc_val_table.latest()

print(
    f"Using latest training table {tlc_train_table.url}"
    if tlc_train_table.url != initial_train_url
    else f"Using source training table {initial_train_url}"
)
print(
    f"Using latest validation table {tlc_val_table.url}"
    if tlc_val_table.url != initial_val_url
    else f"Using source validation table {initial_val_url}"
)

# Views that apply the transforms when samples are read; the Tables themselves are not modified.
train_view = tlc_train_table.with_transform(train_fn)
train_collect_view = tlc_train_table.with_transform(val_fn)  # un-augmented training view, for metrics collection
val_view = tlc_val_table.with_transform(val_fn)

Setup Model

We use a ResNet-18 model from the timm model repository.

[ ]:
import timm

# Every input has the same spatial size after Resize(224), so let cuDNN benchmark and
# cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

# ResNet-18 from timm, with its classifier sized for CIFAR-10 and dropout/drop-path regularization.
model = timm.create_model(
    TIMM_MODEL_NAME,
    pretrained=PRETRAINED,
    num_classes=NUM_CLASSES,
    drop_rate=DROP_RATE,
    drop_path_rate=DROP_PATH_RATE,
)
model = model.to(device)

Setup Training Loop

[ ]:
# Loss function for single-label classification.
criterion = torch.nn.CrossEntropyLoss()

# SGD with Nesterov momentum and light weight decay.
_sgd_kwargs = {"lr": INITIAL_LR, "momentum": 0.9, "weight_decay": 1e-4, "nesterov": True}
optimizer = torch.optim.SGD(model.parameters(), **_sgd_kwargs)

# Multiply the learning rate by LR_GAMMA on every `lr_scheduler.step()`.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)

# Gradient scaling for mixed precision, enabled only on CUDA devices.
_amp_enabled = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=_amp_enabled)
[ ]:
def train(model, loader, criterion, optimizer, scaler):
    """Run one training epoch over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    Mixed precision (``torch.autocast``) is enabled only when the global ``device`` is a
    CUDA device. Batches are moved to ``device`` before the forward pass.

    Args:
        model: Network to train; it is switched to train mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function. Its result is re-weighted by batch size, so it is
            expected to return the batch-mean loss (the default for ``CrossEntropyLoss``).
        optimizer: Optimizer for ``model``'s parameters.
        scaler: ``GradScaler`` that scales the loss and drives the optimizer step.

    Returns:
        The sample-weighted mean loss and the top-1 accuracy in percent.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()
    use_amp = device.type == "cuda"
    train_loss = 0.0
    correct = 0
    total = 0
    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        # `torch.autocast("cuda", ...)` is the non-deprecated spelling of
        # `torch.cuda.amp.autocast(...)` and has the same semantics.
        with torch.autocast("cuda", enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = labels.size(0)
        train_loss += loss.item() * batch_size
        # Detach instead of using `.data`, so bookkeeping never touches the autograd graph.
        predicted = outputs.detach().max(dim=1).indices
        total += batch_size
        correct += (predicted == labels).sum().item()
    return train_loss / total, 100 * correct / total
[ ]:
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(mean_loss, accuracy_percent)``, where the loss is weighted by batch size.
        Raises ``ZeroDivisionError`` if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating"):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            batch_loss = criterion(logits, labels)
            batch_size = labels.size(0)
            loss_sum += batch_loss.item() * batch_size
            n_seen += batch_size
            n_correct += (logits.max(dim=1).indices == labels).sum().item()
    return loss_sum / n_seen, 100 * n_correct / n_seen

Setup Metrics Collectors

[ ]:
# List every named sub-module together with its position in `model.named_modules()`.
# The embeddings collector and predictor refer to a layer by that position.
indices_and_modules = list(enumerate(model.named_modules()))
for idx, (name, _module) in indices_and_modules:
    print(idx, name)

# Use the last module in traversal order (the final fully connected layer) for
# embeddings collection.
# NOTE(review): despite the "flatten" in the variable names, this is simply the last
# entry of `named_modules()`; confirm it is the layer whose output you want as embeddings.
_final_position, _final_entry = indices_and_modules[-1]
final_flatten_layer_index = _final_position
final_flatten_layer_name = _final_entry[0]

print(f"Using layer {final_flatten_layer_index} ({final_flatten_layer_name}) for embeddings collection")
[ ]:
import torch.nn.functional as F


# Per-sample metrics function, passed to a FunctionalMetricsCollector below.
def metrics_fn(batch, predictor_output: "tlc.metrics.PredictorOutput"):
    """Compute per-sample classification metrics for one batch.

    Args:
        batch: A batch from the metrics-collection dataloader. Element 1 holds the integer
            class labels; element 0 (the images) is not used here.
        predictor_output: The 3LC predictor output. Its ``.forward`` attribute holds the
            model logits, assumed to be shaped ``(batch, num_classes)``.

    Returns:
        A dict of 1-D NumPy arrays with one value per sample:
        ``loss`` (cross entropy), ``predicted`` (argmax class index), ``accuracy``
        (1.0 if correct, else 0.0), ``confidence`` (softmax probability of the predicted
        class), and ``rmse`` / ``mae`` between the softmax output and the one-hot label.
    """
    predictions = predictor_output.forward
    # Move labels to wherever the logits live, rather than to the notebook-global
    # `device`, so the two tensors can never end up on different devices.
    labels = batch[1].to(predictions.device)

    # Metrics need no gradients; this also keeps `.numpy()` valid on every result.
    with torch.no_grad():
        num_classes = predictions.shape[1]
        one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()

        probabilities = F.softmax(predictions, dim=1)
        predicted_indices = torch.argmax(predictions, dim=1)
        # Probability assigned to the predicted class.
        confidence = probabilities.gather(1, predicted_indices.unsqueeze(1)).squeeze(1)

        accuracy = (predicted_indices == labels).float()
        cross_entropy_loss = F.cross_entropy(predictions, labels, reduction="none")

        # Errors between the probability vector and the one-hot target, averaged over classes.
        rmse = torch.sqrt(F.mse_loss(probabilities, one_hot_labels, reduction="none").mean(dim=1))
        mae = F.l1_loss(probabilities, one_hot_labels, reduction="none").mean(dim=1)

    return {
        "loss": cross_entropy_loss.cpu().numpy(),
        "predicted": predicted_indices.cpu().numpy(),
        "accuracy": accuracy.cpu().numpy(),
        "confidence": confidence.cpu().numpy(),
        "rmse": rmse.cpu().numpy(),
        "mae": mae.cpu().numpy(),
    }


# Output schemas are inferred automatically. Declare them explicitly only where
# customization is wanted (description, display name, display importance, class names, etc.).
schemas = {
    "loss": tlc.schemas.Float32Schema(description="Cross entropy loss"),
    "predicted": tlc.schemas.CategoricalLabelSchema(display_name="predicted label", classes=class_names),
}

# Metrics collectors: per-sample classification metrics computed by `metrics_fn`, and
# embeddings taken from the layer selected by index above.
classification_metrics_collector = tlc.metrics.FunctionalMetricsCollector(collection_fn=metrics_fn, schema=schemas)
embeddings_metrics_collector = tlc.metrics.EmbeddingsMetricsCollector(layers=[final_flatten_layer_index])

Run training

We have now defined our training and validation datasets, defined our model, and configured our metrics collectors. We are ready to run training.

[ ]:
from tlc.integration.torch.samplers import create_sampler
from torch.utils.data import DataLoader

# Weighted sampler built from the training Table; it sets each sample's sampling probability.
sampler = create_sampler(tlc_train_table)

# Wrap the model for metrics collection. `layers` uses the same layer index as the
# embeddings collector, presumably so that layer's output can be captured.
predictor = tlc.metrics.Predictor(model, layers=[final_flatten_layer_index])

# Training and validation dataloaders over the transformed 3LC views.
_loader_kwargs = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
train_loader = DataLoader(train_view, sampler=sampler, **_loader_kwargs)
val_loader = DataLoader(val_view, **_loader_kwargs)

# Metrics collection runs without backpropagation, so a larger batch size is used.
metrics_collection_dataloader_args = {"num_workers": NUM_WORKERS, "batch_size": 512}

# The learning rate is recorded as a constant per metrics-collection call and is hidden
# by default in the Dashboard.
learning_rate_schema = tlc.schemas.Float32Schema(
    display_name="LR",
    description="Learning rate",
    default_visible=False,
)

def _collects_embeddings(epoch, num_epochs, frequency):
    """Return True if embeddings should be collected after the 0-based ``epoch``.

    Embeddings are collected every ``frequency`` epochs (starting with the first epoch)
    and always after the final epoch, so the end-of-training reduction has embeddings
    from the trained model. A non-positive ``frequency`` means "final epoch only".
    """
    if epoch == num_epochs - 1:
        return True
    return frequency > 0 and epoch % frequency == 0


# Train the model
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, scaler)
    val_loss, val_acc = validate(model, val_loader, criterion)

    # Learning rate used during this epoch; the scheduler is stepped at the end of the loop.
    current_lr = lr_scheduler.get_last_lr()[0]

    # Classification metrics are collected every epoch. Embeddings are collected only
    # every EMBEDDINGS_COLLECTION_FREQUENCY epochs (and after the last epoch).
    metrics_collectors = [classification_metrics_collector]
    if _collects_embeddings(epoch, EPOCHS, EMBEDDINGS_COLLECTION_FREQUENCY):
        metrics_collectors.append(embeddings_metrics_collector)

    # The training split uses the non-augmenting view, so metrics reflect the original images.
    for split, collect_view in (("val", val_view), ("train", train_collect_view)):
        tlc.collect_metrics(
            collect_view,
            metrics_collectors=metrics_collectors,
            predictor=predictor,
            split=split,
            constants={"epoch": epoch, "learning_rate": current_lr},
            constants_schemas={"learning_rate": learning_rate_schema},
            dataloader_args=metrics_collection_dataloader_args,
        )

    print(
        f"Epoch {epoch + 1}/{EPOCHS}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
        f"lr: {current_lr:.6f}"
    )
    print(f"Epoch {epoch + 1}/{EPOCHS}: Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    lr_scheduler.step()
[ ]:
# Fit a PaCMAP reducer on the embeddings collected for the *training* table (the URL
# passed below) and reduce embeddings to 3 components.
# NOTE(review): `url_mapping` presumably maps the original metrics tables to their reduced
# counterparts; confirm against the 3LC Run API docs.
url_mapping = run.reduce_embeddings_by_foreign_table_url(
    tlc_train_table.url,
    method="pacmap",
    n_components=3,
)