Hugging Face CIFAR-100 Embeddings Example#
In this notebook, we use a pre-trained Vision Transformer (ViT) model to collect embeddings on the CIFAR-100 dataset.
This notebook demonstrates:
Registering the CIFAR-100 dataset from Hugging Face.
Computing image embeddings with transformers and reducing them to 2D with UMAP.
Adding the computed embeddings as metrics to a 3LC Run.
Project Setup#
[2]:
PROJECT_NAME = "CIFAR-100"
RUN_NAME = "Collect Image Embeddings"
DESCRIPTION = "Collect image embeddings from ViT model on CIFAR-100"
DEVICE = None
TRAIN_DATASET_NAME = "hf-cifar-100-train"
TEST_DATASET_NAME = "hf-cifar-100-test"
MODEL = "google/vit-base-patch16-224"
BATCH_SIZE = 32
TRANSIENT_DATA_PATH = "../transient_data"
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
INSTALL_DEPENDENCIES = False
[4]:
%%capture
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install 3lc[umap,huggingface]
Imports#
[7]:
import logging
import datasets
import tlc
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) # Reduce model loading logs
datasets.utils.logging.disable_progress_bar()
Prepare the data#
To read the data into 3LC, we use tlc.Table.from_hugging_face(), available under the Hugging Face integration. This returns a Table that works similarly to a Hugging Face datasets.Dataset.
[8]:
cifar100_train = tlc.Table.from_hugging_face(
    "cifar100",
    split="train",
    table_name="train",
    project_name=PROJECT_NAME,
    dataset_name=TRAIN_DATASET_NAME,
    description="CIFAR-100 training dataset",
    if_exists="overwrite",
)
cifar100_test = tlc.Table.from_hugging_face(
    "cifar100",
    split="test",
    table_name="test",
    project_name=PROJECT_NAME,
    dataset_name=TEST_DATASET_NAME,
    description="CIFAR-100 test dataset",
    if_exists="overwrite",
)
[9]:
cifar100_train[0]["img"]
[9]:
Compute the data#
We then use the transformers library to compute embeddings and umap-learn to reduce the embeddings to two dimensions.
[10]:
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import ViTImageProcessor, ViTModel
if DEVICE is None:
    if torch.cuda.is_available():
        device = "cuda:0"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
else:
    device = DEVICE
device = torch.device(device)
print(f"Using device: {device}")
Using device: cuda
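The device cell applies a fixed precedence: an explicit DEVICE setting first, then CUDA, then Apple MPS, then CPU. A minimal sketch of that precedence as a pure function follows. The availability flags are passed in so it runs without torch installed. The name `resolve_device` and its parameters are hypothetical and not part of the notebook or of PyTorch.

```python
from typing import Optional


def resolve_device(requested: Optional[str], cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper mirroring the notebook's device precedence."""
    if requested is not None:
        # An explicit setting always wins over auto-detection.
        return requested
    if cuda_available:
        return "cuda:0"
    if mps_available:
        return "mps"
    return "cpu"


# Example: no override, no CUDA, but MPS present.
chosen = resolve_device(None, cuda_available=False, mps_available=True)
print(chosen)
```

In the notebook itself the two flags would come from torch.cuda.is_available() and torch.backends.mps.is_available().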
[11]:
feature_extractor = ViTImageProcessor.from_pretrained(MODEL)
model = ViTModel.from_pretrained(MODEL).to(device)
[12]:
def extract_feature(sample):
    return feature_extractor(images=sample["img"], return_tensors="pt")
[13]:
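def infer_on_dataset(dataset):
    activations = []
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    with torch.no_grad():  # inference only, so gradients are not tracked
        for inputs in tqdm(dataloader, total=len(dataloader)):
            # Drop only the extra length-1 axis added per sample; keep the batch axis
            inputs["pixel_values"] = inputs["pixel_values"].squeeze(1)
            # Use the resolved `device`; the DEVICE setting may be None
            inputs = inputs.to(device)
            outputs = model(**inputs)
            # Keep the hidden state at token position 0 ([CLS]) for each image
            activations.append(outputs.last_hidden_state[:, 0, :].detach().cpu())
    return activations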
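Because the image processor runs on one image at a time, each sample's pixel_values presumably has a leading axis of length 1, so a collated batch has shape (batch, 1, 3, 224, 224) until it is squeezed. The sketch below uses NumPy arrays as stand-ins for the tensors, with a small 8 x 8 image size to keep it light. It compares squeezing every length-1 axis with squeezing only axis 1. `collated` is a hypothetical helper.

```python
import numpy as np


def collated(batch_size: int) -> np.ndarray:
    """Toy stand-in for a collated batch: (batch, 1, channels, height, width)."""
    return np.zeros((batch_size, 1, 3, 8, 8), dtype=np.float32)


full = collated(4)
single = collated(1)

# Squeezing all length-1 axes is fine while the batch has more than one image...
shape_full_all = full.squeeze().shape
# ...but it also drops the batch axis when the batch holds a single image.
shape_single_all = single.squeeze().shape
# Squeezing only axis 1 keeps the batch axis in both cases.
shape_single_axis1 = single.squeeze(axis=1).shape

print(shape_full_all, shape_single_all, shape_single_axis1)
```

With BATCH_SIZE = 32, the 50,000 training and 10,000 test images both leave a final batch of 16, so the single-image case does not occur in this notebook. It can occur with other batch sizes or dataset lengths.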
[14]:
activations = []
model.eval()
for dataset in (cifar100_train, cifar100_test):
    dataset = dataset.map(extract_feature)
    activations.extend(infer_on_dataset(dataset))
[15]:
activations = torch.cat(activations).numpy()
activations.shape
[15]:
(60000, 768)
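The 60,000 rows are the 50,000 training and 10,000 test images. The 768 columns are the hidden size of the ViT-base model. Each row is the hidden state at token position 0, the [CLS] token, selected with last_hidden_state[:, 0, :]. The NumPy sketch below uses a random stand-in for the model output to show how that slice and the later concatenation produce this shape. The token count of 197 is an assumption: 196 patches of 16 x 16 pixels from a 224 x 224 image, plus one [CLS] token.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-in for model output: (batch, tokens, hidden).
batch, tokens, hidden = 4, 197, 768
last_hidden_state = rng.normal(size=(batch, tokens, hidden)).astype(np.float32)

# One vector per image: the hidden state at token position 0.
cls_embeddings = last_hidden_state[:, 0, :]

# Per-batch results are stacked along the first axis, as torch.cat does in the notebook.
stacked = np.concatenate([cls_embeddings, cls_embeddings], axis=0)

print(cls_embeddings.shape, stacked.shape)
```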
[16]:
import umap
reducer = umap.UMAP(n_components=2)
embeddings_2d = reducer.fit_transform(activations)
Collect the embeddings as 3LC metrics#
In this example, the metrics are contained in a numpy.ndarray object. We can specify the schema of this data and provide it directly to 3LC using Run.add_metrics().
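The loop that follows uses embeddings_2d_train, embeddings_2d_test, and a run object whose defining cells are not reproduced here. Since the activations were collected in train-then-test order, one plausible way to recover per-table embeddings is to slice the reduced array by table length. The sketch below assumes that ordering and uses toy sizes in place of len(cifar100_train) and len(cifar100_test). `split_by_lengths` is a hypothetical helper. The run object is not sketched.

```python
import numpy as np


def split_by_lengths(array: np.ndarray, lengths: list) -> list:
    """Hypothetical helper: split rows of `array` into consecutive chunks of the given lengths."""
    if sum(lengths) != len(array):
        raise ValueError("lengths must add up to the number of rows")
    # Split points are the cumulative lengths, excluding the final total.
    offsets = np.cumsum(lengths)[:-1]
    return np.split(array, offsets, axis=0)


# Toy stand-ins: 5 "train" rows followed by 2 "test" rows of 2D embeddings.
n_train, n_test = 5, 2
embeddings_2d = np.arange((n_train + n_test) * 2, dtype=np.float32).reshape(-1, 2)

embeddings_2d_train, embeddings_2d_test = split_by_lengths(embeddings_2d, [n_train, n_test])
print(embeddings_2d_train.shape, embeddings_2d_test.shape)
```

Each part then has one 2D vector per table row, which is what a per-row metric such as the "embeddings" column needs.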
[21]:
for dataset, embeddings in ((cifar100_train, embeddings_2d_train), (cifar100_test, embeddings_2d_test)):
    run.add_metrics(
        {"embeddings": list(embeddings)},
        column_schemas={"embeddings": tlc.FloatVector2Schema()},
        foreign_table_url=dataset.url,
    )