View source Download .ipynb

Collect metrics using a pre-trained model

This notebook demonstrates how to use an already-trained model to collect classification metrics on the CIFAR-10 train and validation tables, without running any training.

image1

Install dependencies

[ ]:
# Install the 3LC client, timm (used to load the pre-trained model) and the 3LC example
# repository, which presumably provides the `tlc_tools` package imported below.
%pip install 3lc
%pip install timm
%pip install git+https://github.com/3lc-ai/3lc-examples.git

Imports

[ ]:
import timm
import tlc
import torchvision

from tlc_tools.common import infer_torch_device

Project setup

[ ]:
# Number of DataLoader worker processes used while collecting metrics.
# NOTE(review): where DataLoader workers are spawned rather than forked (e.g. Windows,
# macOS), functions defined in a notebook may fail to pickle; use 0 if workers error out.
NUM_WORKERS: int = 4

Prepare Tables

We reuse the tables created in the notebook create-table-from-torch.ipynb and load a ResNet-18 that has already been trained on CIFAR-10 from the Hugging Face Hub.

[ ]:
device = infer_torch_device()

# ResNet-18 already trained on CIFAR-10, fetched through timm from the Hugging Face Hub,
# then moved onto the inferred device.
model = timm.create_model("hf_hub:FredMell/resnet18-cifar10", pretrained=True)
model = model.to(device)

# Open the tables named "initial" for the train and val datasets of the tutorial project
# (train first, then val).
train_table, val_table = (
    tlc.Table.from_names("initial", dataset_name, "3LC Tutorials - CIFAR-10")
    for dataset_name in ("CIFAR-10-train", "CIFAR-10-val")
)
[ ]:
# Preprocessing applied to every image before it reaches the model.
# NOTE(review): confirm that 224 px inputs and ImageNet statistics match the preprocessing
# the hub checkpoint was trained with.
image_transform = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.Resize(size=224),  # shorter edge -> 224 px
        torchvision.transforms.ToTensor(),  # image -> float CHW tensor
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet per-channel mean
            std=[0.229, 0.224, 0.225],  # ImageNet per-channel std
        ),
    ]
)


def transform(sample):
    """Turn a table sample into a ``(model_input, label)`` pair.

    Only the first two elements of *sample* are read: the image is passed through
    ``image_transform`` and the label is returned unchanged.
    """
    image, label = sample[0], sample[1]
    return image_transform(image), label
[ ]:
# Apply the transforms to the tables to ensure model-compatibility (ensure any existing maps are cleared first)

def _reset_and_map(table):
    """Remove any maps already registered on *table*, then register ``transform``."""
    table.clear_maps()
    return table.map(transform)


# Make both tables yield model-ready samples, starting from a clean set of maps.
train_table = _reset_and_map(train_table)
val_table = _reset_and_map(val_table)

Collect metrics

[ ]:
# Create a 3LC run and collect metrics
run = tlc.init(
    project_name=train_table.project_name,
    description="Only collect metrics with trained model on CIFAR-10",
)

# nn.Module instances (including freshly created timm models) start in training mode.
# Switch to inference mode so BatchNorm uses its running statistics (and any dropout is
# disabled); otherwise the collected predictions depend on batch composition.
model.eval()

dataloader_args = {
    "batch_size": 128,
    "num_workers": NUM_WORKERS,
    # Pinned host memory only helps host-to-CUDA copies; elsewhere PyTorch just warns
    # and ignores it, so only request it when running on a CUDA device.
    "pin_memory": str(device).startswith("cuda"),
}

# Class names taken from the values of the table's "Label" value map.
classes = list(train_table.get_simple_value_map("Label").values())

# Same model and settings for both splits; a fresh collector is created per split.
for split_table, split_name in ((train_table, "train"), (val_table, "val")):
    tlc.collect_metrics(
        table=split_table,
        predictor=model,
        metrics_collectors=tlc.ClassificationMetricsCollector(classes=classes),
        dataloader_args=dataloader_args,
        split=split_name,
    )

run.set_status_completed()