Collect metrics using a pre-trained model¶
This notebook demonstrates how to use a pre-trained classification model to collect metrics on the CIFAR-10 tables, without doing any training.

Install dependencies¶
[ ]:
%pip install 3lc
%pip install timm
%pip install git+https://github.com/3lc-ai/3lc-examples.git
Imports¶
[ ]:
import timm
import tlc
import torchvision
from tlc_tools.common import infer_torch_device
Project setup¶
[ ]:
NUM_WORKERS = 4  # Number of worker processes for the metrics dataloaders
Prepare Tables¶
We will reuse the tables created in the notebook create-table-from-torch.ipynb, and use a ResNet-18 model from the Hugging Face Hub that has already been trained on CIFAR-10.
[ ]:
device = infer_torch_device()
# Use a resnet18 model from timm, already trained on CIFAR-10
model = timm.create_model("hf_hub:FredMell/resnet18-cifar10", pretrained=True).to(device)
# Load the tables
train_table = tlc.Table.from_names("initial", "CIFAR-10-train", "3LC Tutorials - CIFAR-10")
val_table = tlc.Table.from_names("initial", "CIFAR-10-val", "3LC Tutorials - CIFAR-10")
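Optionally, sanity-check what the raw tables contain before mapping them. The cell below is an optional sketch; it assumes each row is an (image, label) pair, as produced in the table-creation notebook, which is also what the transform in the next cell expects.
[ ]:
# Optional sanity check: inspect one raw row and the table sizes
# Assumption: each row is an (image, label) pair, matching the transform defined below
sample = train_table[0]
print(type(sample[0]), sample[1])
print(f"{len(train_table)} train rows, {len(val_table)} val rows")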
[ ]:
image_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def transform(sample):
    image = sample[0]
    label = sample[1]
    return (image_transform(image), label)
[ ]:
# Apply the transform to the tables so they produce model-compatible samples (clear any existing maps first)
train_table.clear_maps()
train_table = train_table.map(transform)
val_table.clear_maps()
val_table = val_table.map(transform)
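As a quick, optional check that the mapped tables now yield model-ready inputs, the sketch below inspects one transformed sample and runs a single forward pass; the image should come out as a 3×224×224 float tensor, and the logits should have one entry per CIFAR-10 class.
[ ]:
# Optional: verify that a mapped sample is a normalized 224x224 tensor and that the model accepts it
import torch

model.eval()  # inference mode for this check (also what we want during metrics collection)
image, label = train_table[0]
print(image.shape, image.dtype, label)  # expected: torch.Size([3, 224, 224]), torch.float32

with torch.no_grad():
    logits = model(image.unsqueeze(0).to(device))
print(logits.shape)  # expected: torch.Size([1, 10]) for this CIFAR-10 model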
Collect metrics¶
[ ]:
# Create a 3LC run and collect metrics
run = tlc.init(
    project_name=train_table.project_name,
    description="Only collect metrics with trained model on CIFAR-10",
)

dataloader_args = {
    "batch_size": 128,
    "num_workers": NUM_WORKERS,
    "pin_memory": True,
}

# Human-readable class names, taken from the table's value map for the "Label" column
classes = list(train_table.get_simple_value_map("Label").values())

tlc.collect_metrics(
    table=train_table,
    predictor=model,
    metrics_collectors=tlc.ClassificationMetricsCollector(classes=classes),
    dataloader_args=dataloader_args,
    split="train",
)
tlc.collect_metrics(
    table=val_table,
    predictor=model,
    metrics_collectors=tlc.ClassificationMetricsCollector(classes=classes),
    dataloader_args=dataloader_args,
    split="val",
)
run.set_status_completed()
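The collected per-sample metrics are stored with the run and are most easily explored in the 3LC Dashboard, alongside the input tables. As a minimal programmatic check, assuming the run object exposes a url attribute like other 3LC objects, you can print where the run was written:
[ ]:
# Assumption: the Run object exposes a `url` attribute pointing to where it is stored
print(run.url)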