Train autoencoder for embedding extraction¶
This notebook showcases more ways of working with metrics and embeddings in 3LC. It demonstrates how to collect embeddings and image metrics from a manually trained model.

The autoencoder architecture is used purely as a simple example of the process: the model should be seen as an embedding extractor that also produces reconstructed images as a side effect.
Install dependencies¶
[ ]:
%pip install 3lc[pacmap]
%pip install git+https://github.com/3lc-ai/3lc-examples.git
%pip install timm
Imports¶
[ ]:
import tlc
import torch
import torch.nn as nn
from timm import create_model
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from tlc_tools.common import infer_torch_device
Project setup¶
[ ]:
DOWNLOAD_PATH = "../../transient_data"
PROJECT_NAME = "3LC Tutorials - CIFAR-10"
RUN_NAME = "Train autoencoder"
RUN_DESCRIPTION = "Train an autoencoder and collect embeddings and reconstructions"
BACKBONE = "resnet50"
EMBEDDING_DIM = 512 # Desired embedding dimension
EPOCHS = 10
FREEZE_ENCODER = False
IMAGE_WIDTH = 32
IMAGE_HEIGHT = 32
NUM_CHANNELS = 3
BATCH_SIZE = 64
METHOD = "pacmap"
NUM_COMPONENTS = 2
NUM_WORKERS = 0
[ ]:
CHECKPOINT_PATH = DOWNLOAD_PATH + "/autoencoder_model.pth"
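The checkpoint will be written to this path after training. As an optional safeguard (not part of the original notebook), a cell like the following creates the directory first, so that the later call to torch.save does not fail on a missing path:

[ ]:
# Optional: create the checkpoint directory up front if it does not exist yet.
from pathlib import Path

Path(DOWNLOAD_PATH).mkdir(parents=True, exist_ok=True)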
Load input Tables¶
[ ]:
train_table = tlc.Table.from_names("initial", "CIFAR-10-train", PROJECT_NAME)
val_table = tlc.Table.from_names("initial", "CIFAR-10-val", PROJECT_NAME)
[ ]:
# Prepare data
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)


def map_fn(sample):
    """Map samples from the table to be compatible with the model."""
    image = sample[0]
    image = transform(image)
    return image
train_table.clear_maps()
train_table.map(map_fn)
val_table.clear_maps()
val_table.map(map_fn)
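As a quick sanity check (optional, not part of the original flow), a mapped sample can be inspected directly. Assuming the mapping above, each sample should be a float tensor of shape (NUM_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH):

[ ]:
# Optional sanity check: indexing the mapped table should yield a (3, 32, 32) float tensor.
sample = train_table[0]
print(type(sample), sample.shape, sample.dtype)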
Train autoencoder¶
[ ]:
class Autoencoder(nn.Module):
    def __init__(self, backbone_name="resnet50", embedding_dim=512, freeze_encoder=FREEZE_ENCODER):
        super().__init__()

        # Load the backbone as an encoder
        self.encoder = create_model(backbone_name, pretrained=True, num_classes=0)
        encoder_output_dim = self.encoder.feature_info[-1]["num_chs"]

        # Freeze encoder parameters if specified
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        # Add a projection layer to reduce to embedding_dim
        self.projector = nn.Linear(encoder_output_dim, embedding_dim)

        # Define the decoder
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, encoder_output_dim),
            nn.ReLU(),
            nn.Linear(encoder_output_dim, IMAGE_HEIGHT * IMAGE_WIDTH * NUM_CHANNELS),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Encoder
        features = self.encoder(x)
        embeddings = self.projector(features)

        # Decoder (reshape to NCHW: batch, channels, height, width)
        reconstructions = self.decoder(embeddings)
        reconstructions = reconstructions.view(x.size(0), NUM_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH)
        return embeddings, reconstructions
[ ]:
# Initialize the model
model = Autoencoder(backbone_name=BACKBONE, embedding_dim=EMBEDDING_DIM)
# Training Components
criterion = nn.MSELoss() # Reconstruction loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Create data loaders
train_loader = DataLoader(train_table, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_table, batch_size=BATCH_SIZE, shuffle=False)
device = infer_torch_device()
model.to(device)
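Before starting training, an optional shape check (not in the original notebook) can confirm that a batch flows through the model as expected:

[ ]:
# Optional: run a random dummy batch through the model to verify output shapes.
model.eval()
with torch.no_grad():
    dummy = torch.randn(2, NUM_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH, device=device)
    emb, rec = model(dummy)
print(emb.shape)  # expected: torch.Size([2, EMBEDDING_DIM])
print(rec.shape)  # expected: torch.Size([2, 3, 32, 32])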
[ ]:
# Training loop
for epoch in range(EPOCHS):
    model.train()
    epoch_train_loss = 0.0
    epoch_val_loss = 0.0

    for images in tqdm(train_loader, desc="Training", total=len(train_loader)):
        images = images.to(device)

        # Forward pass
        embeddings, reconstructions = model(images)

        # Compute loss
        loss = criterion(reconstructions, images)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()

    # Validation pass
    model.eval()
    with torch.no_grad():
        for images in tqdm(val_loader, desc="Validation", total=len(val_loader)):
            images = images.to(device)

            # Forward pass
            embeddings, reconstructions = model(images)

            # Compute loss
            loss = criterion(reconstructions, images)
            epoch_val_loss += loss.item()

    epoch_train_loss /= len(train_loader)
    epoch_val_loss /= len(val_loader)
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")
[ ]:
# Save the model
torch.save(model.state_dict(), CHECKPOINT_PATH)
print(f"Model saved to {CHECKPOINT_PATH}")
Collect metrics from the trained model¶
[ ]:
unreduced_loss = nn.MSELoss(reduction="none")  # Per-sample reconstruction loss


def metrics_fn(batch, predictor_output):
    embeddings, reconstructions = predictor_output.forward
    reconstructed_images = [transforms.ToPILImage()(image.cpu()) for image in reconstructions]
    reconstruction_loss = unreduced_loss(reconstructions.to(device), batch.to(device)).mean(dim=(1, 2, 3))
    return {
        "embeddings": embeddings.cpu().detach().numpy(),
        "reconstructions": reconstructed_images,
        "reconstruction_loss": reconstruction_loss.cpu().detach().numpy(),
    }
[ ]:
run = tlc.init(project_name=PROJECT_NAME, run_name=RUN_NAME, description=RUN_DESCRIPTION)
tlc.collect_metrics(
    train_table,
    metrics_fn,
    model,
    collect_aggregates=False,
    dataloader_args={"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS},
)
tlc.collect_metrics(
    val_table,
    metrics_fn,
    model,
    collect_aggregates=False,
    dataloader_args={"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS},
)
Reduce embeddings to 2D¶
[ ]:
run.reduce_embeddings_by_foreign_table_url(
    train_table.url,
    source_embedding_column="embeddings",
    method=METHOD,
    n_components=NUM_COMPONENTS,
)
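The reduced 2D embeddings are stored with the run and can be explored in the 3LC Dashboard. As an optional final step (not part of the original notebook), printing the run URL makes the run easy to locate:

[ ]:
# Optional: print the run URL for easy lookup in the 3LC Dashboard.
print(run.url)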