Prerequisites

This notebook reuses tables created by other example notebooks, such as create-table-from-torch.ipynb. Run those notebooks first so that the tables exist.

Collect metrics using a pre-trained model¶

This notebook demonstrates how to use a pre-trained model to collect metrics on a dataset.

image1

Project setup¶

[ ]:
# Project and run name passed to `tlc.init` when the metrics run is created.
PROJECT_NAME = "3LC Tutorials - CIFAR-10"
RUN_NAME = "collect-metrics-cifar-10"
# Value of the `num_workers` entry in the dataloader arguments used for metrics collection.
NUM_WORKERS = 4
# Set to False to skip the `%pip install` cell, e.g. when the packages are already installed.
INSTALL_DEPENDENCIES = True

Install dependencies¶

[ ]:
# Install the third-party packages used by this notebook; skipped when
# INSTALL_DEPENDENCIES is False.
# NOTE: `%pip` is an IPython magic, so this cell only runs inside a notebook kernel.
if INSTALL_DEPENDENCIES:
    %pip install -q 3lc
    %pip install -q timm
    # Installed straight from the examples repository; presumably this provides
    # the `tlc_tools` package imported below — confirm.
    %pip install -q git+https://github.com/3lc-ai/3lc-examples.git

Imports¶

[ ]:
import timm
import tlc
import torch
import torch.nn as nn
import torchvision

from tlc_tools.common import infer_torch_device

Prepare Tables¶

We will reuse the tables created by the notebook create-table-from-torch.ipynb, and load a ResNet-18 that has already been trained on CIFAR-10 from the Hugging Face Hub.

[ ]:
# Device the model is moved to.
device = infer_torch_device()

# Use a resnet18 model from timm, already trained on CIFAR-10
model = timm.create_model("hf_hub:FredMell/resnet18-cifar10", pretrained=True).to(device)
# The model is only used for inference in this notebook, so switch it to
# evaluation mode explicitly instead of relying on the metrics collection call
# to do so (NOTE(review): tlc may already handle this; the call is harmless either way).
model.eval()

# Load the tables. The project name comes from PROJECT_NAME so that the tables
# and the run created later always refer to the same project.
train_table = tlc.Table.from_names(
    table_name="initial", dataset_name="CIFAR-10-train", project_name=PROJECT_NAME
)
val_table = tlc.Table.from_names(
    table_name="initial", dataset_name="CIFAR-10-val", project_name=PROJECT_NAME
)
[ ]:
# Per-channel mean and standard deviation used to normalize the image tensor.
# NOTE(review): these match the widely used ImageNet statistics — confirm they
# are what the pre-trained checkpoint expects.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Image preprocessing pipeline: resize, convert to a tensor, then normalize.
_T = torchvision.transforms
image_transform = _T.Compose(
    [
        _T.Resize(224),
        _T.ToTensor(),
        _T.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)


def transform(sample):
    """Turn one table sample into a ``(preprocessed image, label)`` pair.

    Args:
        sample: Mapping with an ``"Image"`` and a ``"Label"`` entry.

    Returns:
        A 2-tuple holding the image after ``image_transform`` has been applied
        and the label exactly as it was stored in ``sample``.
    """
    image, label = sample["Image"], sample["Label"]
    return image_transform(image), label
[ ]:
# Attach the preprocessing function to both tables so that the samples handed
# to metrics collection are (preprocessed image, label) pairs.
# NOTE(review): assumes `with_transform` returns a new view and leaves the
# underlying table untouched — confirm against the tlc Table API.
train_view, val_view = (table.with_transform(transform) for table in (train_table, val_table))

Setup Metrics Collector¶

We use a tlc.metrics.FunctionalMetricsCollector to compute per-sample loss, predicted label, confidence, and accuracy from the model’s raw logits.

[ ]:
# Values of the training table's "Label" value map; used as the class list
# for the schema of the predicted label.
classes = [*train_table.get_simple_value_map("Label").values()]
# Unreduced cross-entropy, so that every sample keeps its own loss value.
criterion = nn.CrossEntropyLoss(reduction="none")


def compute_metrics(batch, predictor_output):
    """Compute per-sample loss, predicted label, confidence and accuracy.

    Args:
        batch: The batch that was fed to the model; ``batch[1]`` holds the labels.
        predictor_output: Object whose ``forward`` attribute holds the model
            output. Softmax and argmax are taken over dimension 1, so this is
            assumed to be a ``(batch, num_classes)`` tensor of raw logits.

    Returns:
        A dict of NumPy arrays, one value per sample, with the keys
        ``"loss"``, ``"predicted"``, ``"confidence"`` and ``"accuracy"``.
    """
    # Detach first: converting a tensor that still tracks gradients to NumPy
    # raises, which would happen if this is ever called outside `torch.no_grad()`.
    predictions = predictor_output.forward.detach()
    # Labels must live on the same device as the predictions for the loss.
    labels = batch[1].to(predictions.device)

    softmax = torch.softmax(predictions, dim=1)
    predicted = torch.argmax(predictions, dim=1)
    # Confidence is the softmax probability of the predicted class.
    confidence = torch.gather(softmax, 1, predicted.unsqueeze(1)).squeeze(1)

    return {
        "loss": criterion(predictions, labels).cpu().numpy(),
        "predicted": predicted.cpu().numpy(),
        "confidence": confidence.cpu().numpy(),
        "accuracy": predicted.eq(labels).float().cpu().numpy(),
    }


# Only the "predicted" metric is given an explicit schema here: it is declared
# as a categorical label over the table's classes.
predicted_label_schema = tlc.schemas.CategoricalLabelSchema(
    display_name="predicted label",
    classes=classes,
)

# Collector that calls `compute_metrics` to produce the per-sample metrics.
metrics_collector = tlc.metrics.FunctionalMetricsCollector(
    collection_fn=compute_metrics,
    schema={"predicted": predicted_label_schema},
)

Collect metrics¶

[ ]:
# Create the 3LC run that the collected metrics are written to; an existing
# run with the same name is overwritten.
run = tlc.init(
    description="Only collect metrics with trained model on CIFAR-10",
    if_exists="overwrite",
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
)

# Arguments passed on to metrics collection for its dataloader.
dataloader_args = dict(batch_size=128, num_workers=NUM_WORKERS, pin_memory=True)

# Collect metrics on the training data first, then on the validation data,
# using the same model, collector and dataloader settings for both.
for split_name, view in (("train", train_view), ("val", val_view)):
    tlc.collect_metrics(
        table=view,
        predictor=model,
        metrics_collectors=metrics_collector,
        dataloader_args=dataloader_args,
        split=split_name,
    )

# Mark the run as completed once both splits have been processed.
run.set_status_completed()