View source Download .ipynb

Fine-tune Hugging Face SegFormer on a custom dataset

This tutorial shows how to fine-tune a model on a custom semantic segmentation dataset with 🤗 Transformers while collecting metrics with 3LC.

It is based on the original notebook found here.

image1

This tutorial uses a small subset of the ADE20K dataset: 5 training images and 5 validation images, with semantic masks covering 150 classes.

During training, per-sample loss, embeddings, and predictions are collected.

Project setup

[ ]:
# Project/run configuration. PROJECT_NAME and DATASET_NAME are passed to 3LC when the tables are created.
PROJECT_NAME = "3LC Tutorials - Semantic Segmentation ADE20k"
DATASET_NAME = "ADE20k_toy_dataset"
# Directory the toy dataset zip is downloaded and extracted into (relative to the current working directory).
DOWNLOAD_PATH = "../../transient_data"
EPOCHS = 200
# Passed as DataLoader ``num_workers``; 0 loads data in the main process.
NUM_WORKERS = 0
BATCH_SIZE = 2

Install dependencies

[ ]:
%pip install 3lc[huggingface] "transformers<=4.56.0"
%pip install git+https://github.com/3lc-ai/3lc-examples.git

Imports

[ ]:
import json
import os
from pathlib import Path

import tlc
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from transformers import SegformerImageProcessor

from tlc_tools.common import download_and_extract_zipfile, infer_torch_device

Download the dataset

[ ]:
DATASET_ROOT = Path(DOWNLOAD_PATH, "ADE20k_toy_dataset").resolve()

# Download and extract the toy dataset only when the extracted folder is not present yet.
_dataset_zip_url = "https://www.dropbox.com/s/l1e45oht447053f/ADE20k_toy_dataset.zip?dl=1"
if not DATASET_ROOT.exists():
    print("Downloading data...")
    download_and_extract_zipfile(url=_dataset_zip_url, location=DOWNLOAD_PATH)

Fetch the label map from the Hugging Face Hub

[ ]:
# Load the ADE20K id -> label-name mapping from a JSON file hosted on the Hugging Face Hub.
repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
label_file = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
# Read explicitly as UTF-8 so the result does not depend on the platform's default encoding.
with open(label_file, encoding="utf-8") as f:
    id2label = json.load(f)
# JSON object keys are always strings; convert them back to integer class ids.
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
# Label map for the un-reduced masks: 0 is "background" and every model class id k is shifted to k + 1.
# NOTE(review): the background key is the float 0.0 while the others are ints; 0.0 == 0, so lookups agree.
unreduced_label_map = {0.0: "background", **{k + 1: v for k, v in id2label.items()}}
[ ]:
# Display the integer id -> label-name mapping loaded above.
id2label

Initialize a Run

[ ]:
DEVICE = infer_torch_device()

# Parameters recorded on the 3LC Run alongside the collected metrics.
run_parameters = {
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "device": str(DEVICE),
}
run = tlc.init(
    PROJECT_NAME,
    description="Train a SegFormer model on ADE20k toy dataset",
    parameters=run_parameters,
)

Set up Torch datasets and 3LC Tables

[ ]:
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset.

    Expects ``root_dir`` to contain ``images/<split>`` and ``annotations/<split>``
    folders, where ``<split>`` is ``training`` or ``validation``. Images and
    annotation files are paired by position after sorting their relative paths.
    """

    def __init__(self, root_dir: str, train: bool = True):
        """
        :param root_dir: Root directory of the dataset containing the images + annotations.
        :param train: Whether to load "training" or "validation" images + annotations.
        :raises ValueError: If the number of images and segmentation maps differ.
        """
        self.root_dir = root_dir
        self.train = train

        sub_path = "training" if self.train else "validation"
        self.img_dir = os.path.join(self.root_dir, "images", sub_path)
        self.ann_dir = os.path.join(self.root_dir, "annotations", sub_path)

        self.images = self._list_files(self.img_dir)
        self.annotations = self._list_files(self.ann_dir)

        # An explicit check (not ``assert``) so it still runs under ``python -O``.
        if len(self.images) != len(self.annotations):
            raise ValueError("There must be as many images as there are segmentation maps")

    @staticmethod
    def _list_files(directory: str) -> list[str]:
        """Return the sorted paths of all files below ``directory``, relative to it."""
        # Keep paths relative to ``directory`` (not bare file names) so files found in
        # sub-folders can still be opened with os.path.join(directory, name).
        return sorted(
            os.path.relpath(os.path.join(dirpath, name), directory)
            for dirpath, _, files in os.walk(directory)
            for name in files
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, segmentation_map, (height, width))`` for sample ``idx``."""
        image = Image.open(os.path.join(self.img_dir, self.images[idx]))
        segmentation_map = Image.open(os.path.join(self.ann_dir, self.annotations[idx]))

        # PIL's ``size`` is (width, height). The original mask size is returned as
        # (height, width) so it can be used to post-process the model output.
        return image, segmentation_map, (segmentation_map.size[1], segmentation_map.size[0])
[ ]:
# Build the training and validation splits from the extracted toy dataset.
train_dataset = SemanticSegmentationDataset(root_dir=DATASET_ROOT, train=True)
val_dataset = SemanticSegmentationDataset(root_dir=DATASET_ROOT, train=False)
[ ]:
# Display the segmentation map (a PIL image) of the first training sample.
train_dataset[0][1]

Create the Tables

[ ]:
# Schema shared by both tables. Its order matches the samples returned by
# SemanticSegmentationDataset: (image, segmentation map, mask size).
structure = (
    tlc.PILImage("image"),
    tlc.SegmentationPILImage("segmentation_map", classes=unreduced_label_map),
    # The dataset returns the mask size as (height, width), so the columns are named in that order.
    tlc.HorizontalTuple("mask size", [tlc.Int("height"), tlc.Int("width")]),
)

train_table = tlc.Table.from_torch_dataset(
    train_dataset,
    structure,
    project_name=PROJECT_NAME,
    dataset_name=DATASET_NAME,
    table_name="train",
    if_exists="overwrite",
)

val_table = tlc.Table.from_torch_dataset(
    val_dataset,
    structure,
    project_name=PROJECT_NAME,
    dataset_name=DATASET_NAME,
    table_name="val",
    if_exists="overwrite",
)
[ ]:
class MapFn:
    """Callable that turns a raw table sample into SegFormer model inputs.

    Registered on the tables with ``Table.map``.
    """

    def __init__(self, image_processor: "SegformerImageProcessor"):
        self.image_processor = image_processor

    def __call__(self, sample):
        """Encode one ``(image, segmentation_map, mask_size)`` sample.

        :returns: The processor output with the leading batch dimension removed from
            every tensor, plus a ``mask_size`` tensor holding the original mask size.
        """
        image, segmentation_map, mask_size = sample
        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")

        # The processor returns a batch of one. Drop only that leading dimension:
        # a bare ``squeeze_()`` would also remove any other size-1 dimension.
        for tensor in encoded_inputs.values():
            tensor.squeeze_(0)

        encoded_inputs.update({"mask_size": torch.tensor(mask_size)})

        return encoded_inputs


# Reduce labels: background (0) becomes the ignore value 255 and every other class id
# is shifted down by one. ``do_reduce_labels`` replaces the deprecated ``reduce_labels``.
image_processor = SegformerImageProcessor(do_reduce_labels=True)

# Apply the image processor to the datasets
train_table.map(MapFn(image_processor))
val_table.map(MapFn(image_processor))
[ ]:
# Inspect the keys of a mapped sample (the image-processor outputs plus "mask_size").
train_table[0].keys()
[ ]:
# Display the URL of the training table.
train_table.url

Define the model

[ ]:
from transformers import SegformerForSemanticSegmentation

# Load the pretrained "nvidia/mit-b0" checkpoint with a segmentation head sized to the label map,
# so the number of classes always agrees with id2label.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
).to(DEVICE)
[ ]:
# Predict on single sample (unsqueeze adds a batch dimension of 1).
# Inference only, so skip building an autograd graph.
with torch.no_grad():
    single_sample_output = model(train_table[0]["pixel_values"].unsqueeze(0).to(DEVICE))
single_sample_output

Set up metrics collection

[ ]:
# 1. EmbeddingsMetricsCollector to collect hidden layer activations
module_names = [name for name, _ in model.named_modules()]
for ind, name in enumerate(module_names):
    print(ind, "=>", name)

# Interesting layers for embedding collection (indices as printed above when this was written):
#   - segformer.encoder.layer_norm.3 (Index: 197)
#   - decode_head.linear_c.2.proj (Index: 204)
#   - decode_head.linear_c.3.proj (Index: 207)
EMBEDDING_LAYER_NAMES = [
    "segformer.encoder.layer_norm.3",
    "decode_head.linear_c.2.proj",
    "decode_head.linear_c.3.proj",
]
# Resolve the indices by name rather than hard-coding them, so they stay correct if the
# module structure changes; list.index raises ValueError if a name no longer exists.
layers = [module_names.index(name) for name in EMBEDDING_LAYER_NAMES]

embedding_collector = tlc.EmbeddingsMetricsCollector(layers=layers)
[ ]:
# 2. A metrics collection callable to collect per-sample loss


def metrics_fn(batch, predictor_output, *, ignore_index=255):
    """Compute the per-sample cross-entropy loss for a batch.

    :param batch: Batch dict; ``batch["labels"]`` holds the reduced label masks.
    :param predictor_output: Predictor output whose ``forward.logits`` are the model logits.
    :param ignore_index: Label value excluded from the loss (255 marks ignored pixels).
    :returns: ``{"loss": array}`` with one loss per sample, averaged over the pixels
        whose label is not ``ignore_index``. A sample with no valid pixels gets 0.
    """
    logits = predictor_output.forward.logits
    # Put the labels on the same device as the logits the model produced.
    labels = batch["labels"].to(logits.device)
    upsampled_logits = torch.nn.functional.interpolate(
        logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
    )
    pixel_loss = torch.nn.functional.cross_entropy(
        upsampled_logits, labels, reduction="none", ignore_index=ignore_index
    )
    # Ignored pixels already have loss 0, but they must not count in the denominator
    # either; otherwise samples with many ignored pixels get an artificially low loss.
    # This matches how the training loss averages over valid pixels only.
    num_valid = (labels != ignore_index).sum(dim=(1, 2)).clamp(min=1)
    loss = pixel_loss.sum(dim=(1, 2)) / num_valid
    return {"loss": loss.detach().cpu().numpy()}
[ ]:
# 3. A SegmentationMetricsCollector to write out the predictions


def preprocess_fn(batch, predictor_output: tlc.PredictorOutput):
    """Turn model logits into per-sample masks in the un-reduced label space.

    Each predicted mask is resized to the original mask size stored in
    ``batch["mask_size"]``. It is then shifted by +1 so that its ids line up
    with ``unreduced_label_map`` (0 = background).

    :returns: The unchanged batch and the list of processed masks.
    """
    target_sizes = batch["mask_size"].tolist()
    predicted_masks = image_processor.post_process_semantic_segmentation(predictor_output.forward, target_sizes)

    def _unreduce(mask):
        # Fold any 255 (ignore) value onto 0 before shifting every id up by one.
        mask[mask == 255] = 0
        return mask + 1

    return batch, [_unreduce(mask) for mask in predicted_masks]


# Collector that writes out the masks returned by preprocess_fn as predictions,
# labelled with the un-reduced label map (0 = background).
segmentation_collector = tlc.SegmentationMetricsCollector(label_map=unreduced_label_map, preprocess_fn=preprocess_fn)
[ ]:
# Define a single function to collect all metrics

# A Predictor object wraps the model and enables embedding-collection
predictor = tlc.Predictor(model, device=DEVICE, layers=layers)

# Control the arguments used for the dataloader used during metrics collection
mc_dataloader_args = {"batch_size": BATCH_SIZE}


def collect_metrics(epoch):
    """Run every metrics collector on the train and val tables, tagging the rows with ``epoch``."""
    for table, split in ((train_table, "train"), (val_table, "val")):
        tlc.collect_metrics(
            table,
            [segmentation_collector, metrics_fn, embedding_collector],
            predictor,
            constants={"epoch": epoch},
            dataloader_args=mc_dataloader_args,
            split=split,
        )


# Collect metrics before training (-1 means before training)
collect_metrics(-1)

Train!

[ ]:
# Uses the "weights" column of the Table to sample the data
sampler = train_table.create_sampler()

# The training loader draws samples through the weighted sampler; the validation loader
# iterates the table in order.
# NOTE(review): valid_dataloader is not used by the training loop in the visible code.
train_dataloader = DataLoader(train_table, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS)
valid_dataloader = DataLoader(val_table, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
[ ]:
def loss_fn(logits, labels, *, num_labels=None, ignore_index=None):
    """Compute the segmentation loss the same way SegFormer does internally.

    The logits are bilinearly upsampled to the label resolution before the loss is taken.

    :param logits: Raw model logits with the class dimension at index 1.
    :param labels: Integer label masks; their last two dims set the upsampling size.
    :param num_labels: Number of classes; defaults to ``model.config.num_labels``.
    :param ignore_index: Label excluded from the loss; defaults to
        ``model.config.semantic_loss_ignore_index``.
    :returns: A scalar loss tensor.
    :raises ValueError: If ``num_labels`` is smaller than 1.
    """
    if num_labels is None:
        num_labels = model.config.num_labels
    if ignore_index is None:
        ignore_index = model.config.semantic_loss_ignore_index

    upsampled_logits = torch.nn.functional.interpolate(
        logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
    )
    if num_labels > 1:
        # Multi-class: cross-entropy averaged over the non-ignored pixels.
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        return loss_fct(upsampled_logits, labels)
    if num_labels == 1:
        # Binary: per-pixel BCE with ignored / negative labels masked out.
        valid_mask = ((labels >= 0) & (labels != ignore_index)).float()
        loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")
        loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
        return (loss * valid_mask).mean()
    # No loss is defined for fewer than one class; fail with a clear message.
    raise ValueError(f"num_labels must be >= 1, got {num_labels}")
[ ]:
# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)

# move model to the training device (DEVICE)
model.to(DEVICE)

for epoch in range(EPOCHS):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    # Re-enter training mode every epoch: metrics collection (below and before the loop)
    # runs the model for inference and may leave it in eval mode.
    model.train()
    agg_loss = 0.0
    seen_samples = 0
    for batch in tqdm(train_dataloader):
        # get the inputs
        pixel_values = batch["pixel_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize.
        # Labels are not passed to the model: the loss is computed once by loss_fn,
        # rather than also being computed (and discarded) inside the model.
        outputs = model(pixel_values=pixel_values)
        loss = loss_fn(outputs.logits, labels)

        loss.backward()
        optimizer.step()

        # Sample-weighted sum, so the epoch mean stays exact when the last batch is smaller.
        num_in_batch = pixel_values.shape[0]
        agg_loss += loss.item() * num_in_batch
        seen_samples += num_in_batch

    # Log the mean training loss of this epoch directly to the active Run
    tlc.log(
        {
            "epoch": epoch,
            "running_train_loss": agg_loss / max(seen_samples, 1),
        }
    )

    if epoch % 50 == 0 and epoch != 0:
        collect_metrics(epoch)

Collect metrics after training

[ ]:
# Collect metrics with the trained model; ``epoch`` still holds the last epoch index from the training loop.
# NOTE(review): if EPOCHS - 1 were a non-zero multiple of 50, this would repeat the in-loop collection for that epoch.
collect_metrics(epoch)

Reduce the dimensionality of the collected metrics