Training a fine-tuned SegFormer model with PyTorch Lightning#

This notebook is a modified version of the official Colab tutorial of “Roboflow How to Train SegFormer” which can be found here.

In this tutorial we will see how to fine-tune a pre-trained SegFormer model for semantic segmentation on a custom dataset. We will integrate with 3LC by creating a training run, registering 3LC datasets, and collecting per-sample predicted masks.

This notebook demonstrates:

  • Training a SegFormer model on a custom dataset with Pytorch Lightning.

  • Registering train/val/test sets into 3LC Tables

  • Collecting per-sample semantic segmentation predicted masks through callbacks

[2]:
# Parameters
PROJECT_NAME = "PyTorch Lightning Image Segmentation"  # 3LC project that owns the run and tables
RUN_NAME = "Train Balloon SegFormer"  # name of the 3LC run created below
DESCRIPTION = "Train a SegFormer model using PyTorch Lightning"
TRAIN_DATASET_NAME = "balloons-train"  # 3LC dataset names for each split
VAL_DATASET_NAME = "balloons-val"
TEST_DATASET_NAME = "balloons-test"
TRANSIENT_DATA_PATH = "../transient_data"  # scratch directory for downloaded/intermediate data
MODEL = "nvidia/mit-b5"  # Hugging Face checkpoint to fine-tune
TEST_DATA_PATH = "../../tests/test_data/data"  # root of the bundled balloon dataset
EPOCHS = 100  # upper bound; early stopping below may end training sooner
BATCH_SIZE = 8
NUM_WORKERS = 0  # single-process data loading
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
INSTALL_DEPENDENCIES = False
[7]:
import os
from typing import Any, Callable

import numpy as np
import pytorch_lightning as pl
import tlc
import torch
from evaluate import load
from matplotlib import pyplot as plt
from PIL import Image
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

Initialize a 3LC Run#

First, we initialize a 3LC run. This will create a new empty run which will be visible in the 3LC dashboard.

[9]:
# Create (or overwrite) the 3LC run; it appears immediately in the dashboard.
run = tlc.init(
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
    description=DESCRIPTION,
    if_exists="overwrite",  # re-running the notebook replaces the previous run
)

Setup Datasets and Training helpers#

We will create a Table with the images and their associated masks.

Moreover, we will also define helpers to preprocess this dataset into a suitable form for training and collecting metrics.

Finally, we define a PyTorch LightningModule that specifies the steps for training, validation, and testing.

[10]:
class TLCSemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset.

    Pairs .jpg images with .png segmentation masks found in ``root_dir``.
    Images and masks are matched by sorted filename order, so the two file
    sets are assumed to correspond one-to-one.
    """

    def __init__(self, root_dir, image_processor):
        """
        :param root_dir: Directory containing the .jpg images and .png masks.
        :param image_processor: SegformerImageProcessor used to encode samples.
        """
        self.root_dir = root_dir
        self.image_processor = image_processor
        # Single directory listing; endswith() avoids false matches such as
        # "photo.jpg.bak" that a plain substring check would pick up.
        file_names = os.listdir(self.root_dir)
        self.images = sorted(f for f in file_names if f.endswith(".jpg"))
        self.masks = sorted(f for f in file_names if f.endswith(".png"))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.images[idx]))
        segmentation_map = Image.open(os.path.join(self.root_dir, self.masks[idx]))
        # Raw PIL objects plus their (width, height) sizes; the transforms
        # below convert these into model-ready tensors.
        return image, segmentation_map, image.size, segmentation_map.size

    def transform_to_seg_former_format(self, sample):
        """Encode a raw sample into SegFormer inputs (pixel_values, labels)."""
        image, segmentation_map, _, _ = sample
        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")
        for key in encoded_inputs.keys():
            encoded_inputs[key].squeeze_()  # remove batch dimension added by the processor
        return encoded_inputs

    def transform_to_collect_metrics(self, sample):
        """Return (pixel_values, labels, image_size, mask_size) for metrics collection."""
        sample_preprocessed = self.transform_to_seg_former_format(sample)
        images, masks = sample_preprocessed["pixel_values"], sample_preprocessed["labels"]
        return (images, masks) + sample[-2:]


def get_id2label(root_dir):
    """Read ``_classes.csv`` from ``root_dir`` and map class ids to label names.

    The CSV is expected to have a header row followed by ``id,label`` rows.

    :param root_dir: Directory containing ``_classes.csv``.
    :returns: Dict mapping the class id (kept as a string) to the stripped label.
    """
    import csv

    classes_csv_file = os.path.join(root_dir, "_classes.csv")
    with open(classes_csv_file, newline="") as fid:
        reader = csv.reader(fid)
        next(reader, None)  # skip the header row
        return {row[0]: row[1].strip() for row in reader}
[11]:
class SegformerFinetuner(pl.LightningModule):
    """LightningModule that fine-tunes a pre-trained SegFormer for semantic segmentation.

    Loss and logits come straight from ``SegformerForSemanticSegmentation``;
    mean IoU and accuracy are tracked per split with the ``evaluate`` library.
    """

    def __init__(
        self,
        id2label,
        image_processor,
        train_dataloader=None,
        val_dataloader=None,
        test_dataloader=None,
        metrics_interval=100,
    ):
        """
        :param id2label: Mapping from class id to class label name.
        :param image_processor: SegformerImageProcessor used for the data.
        :param train_dataloader: DataLoader returned by ``train_dataloader()``.
        :param val_dataloader: DataLoader returned by ``val_dataloader()``.
        :param test_dataloader: DataLoader returned by ``test_dataloader()``.
        :param metrics_interval: Compute/log train metrics every N batches.
        """
        super().__init__()
        self.id2label = id2label
        self.image_processor = image_processor
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.test_dl = test_dataloader
        self.metrics_interval = metrics_interval

        self.num_classes = len(id2label)
        self.label2id = {v: k for k, v in self.id2label.items()}

        # The classification head is re-initialized for our class count;
        # ignore_mismatched_sizes allows loading the checkpoint anyway.
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            MODEL,
            return_dict=False,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True,
        )

        # One metric instance per split so their accumulated state stays separate.
        self.train_mean_iou = load("mean_iou")
        self.val_mean_iou = load("mean_iou")
        self.test_mean_iou = load("mean_iou")

        # Manual step-output accumulation (required for pytorch-lightning >= 2.0).
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, images, masks):
        outputs = self.model(pixel_values=images, labels=masks)
        return outputs

    def _forward_and_predict(self, batch):
        """Shared forward pass for the train/val/test steps.

        :returns: ``(loss, predicted, masks)`` where ``predicted`` is the
            per-pixel argmax of the logits upsampled to the mask resolution.
        """
        images, masks = batch["pixel_values"], batch["labels"]
        outputs = self(images, masks)
        loss, logits = outputs[0], outputs[1]
        # SegFormer predicts at reduced resolution; upsample before comparing to masks.
        upsampled_logits = nn.functional.interpolate(
            logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
        )
        predicted = upsampled_logits.argmax(dim=1)
        return loss, predicted, masks

    def training_step(self, batch, batch_idx):
        loss, predicted, masks = self._forward_and_predict(batch)
        self.train_mean_iou.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        if batch_idx % self.metrics_interval == 0:
            # Periodically compute and log the accumulated train metrics.
            metrics = self.train_mean_iou.compute(
                num_labels=self.num_classes,
                ignore_index=255,
                reduce_labels=False,
            )
            metrics = {
                "loss": loss,
                "mean_iou": metrics["mean_iou"],
                "mean_accuracy": metrics["mean_accuracy"],
            }
            for k, v in metrics.items():
                self.log(k, v, prog_bar=True)
        else:
            metrics = {"loss": loss}

        self.training_step_outputs.append(metrics)
        return metrics

    def validation_step(self, batch):
        loss, predicted, masks = self._forward_and_predict(batch)
        self.val_mean_iou.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        self.validation_step_outputs.append(loss)
        return {"val_loss": loss}

    def on_validation_epoch_end(self):
        metrics = self.val_mean_iou.compute(
            num_labels=self.num_classes,
            ignore_index=255,
            reduce_labels=False,
        )

        avg_val_loss = torch.stack(self.validation_step_outputs).mean()
        metrics = {
            "val_loss": avg_val_loss,
            "val_mean_iou": metrics["mean_iou"],
            "val_mean_accuracy": metrics["mean_accuracy"],
        }
        for k, v in metrics.items():
            self.log(k, v, prog_bar=True)

        self.validation_step_outputs.clear()  # reset accumulation for the next epoch
        return metrics

    def test_step(self, batch):
        loss, predicted, masks = self._forward_and_predict(batch)
        self.test_mean_iou.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        self.test_step_outputs.append(loss)
        return {"test_loss": loss}

    def on_test_epoch_end(self):
        metrics = self.test_mean_iou.compute(
            num_labels=self.num_classes,
            ignore_index=255,
            reduce_labels=False,
        )

        avg_test_loss = torch.stack(self.test_step_outputs).mean()
        metrics = {
            "test_loss": avg_test_loss,
            "test_mean_iou": metrics["mean_iou"],
            "test_mean_accuracy": metrics["mean_accuracy"],
        }
        for k, v in metrics.items():
            self.log(k, v)
        self.test_step_outputs.clear()  # reset accumulation for the next run
        return metrics

    def configure_optimizers(self):
        # Optimize only the trainable parameters.
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)

    def train_dataloader(self):
        return self.train_dl

    def val_dataloader(self):
        return self.val_dl

    def test_dataloader(self):
        return self.test_dl
[12]:
image_processor = SegformerImageProcessor.from_pretrained(
    MODEL,
    reduce_labels=False,
    size=128,
)
dataset_location = tlc.Url(TEST_DATA_PATH + "/balloons-mask-segmentation").to_absolute()

id2label = get_id2label(f"{dataset_location}/train/")  # Assuming the same classes for train, val, and test

# Row schema shared by all three 3LC tables below.
structure = (
    tlc.PILImage("image"),
    tlc.SegmentationPILImage("mask", id2label),
    tlc.HorizontalTuple("image size", [tlc.Int("width"), tlc.Int("height")]),
    tlc.HorizontalTuple("mask size", [tlc.Int("width"), tlc.Int("height")]),
)


def create_tlc_table(split_dir, table_name, dataset_name, **table_kwargs):
    """Build the torch dataset for one split and register it as a 3LC Table.

    :param split_dir: Subdirectory of the dataset ("train", "valid", or "test").
    :param table_name: Name of the 3LC table to create.
    :param dataset_name: Name of the 3LC dataset the table belongs to.
    :param table_kwargs: Extra keyword arguments forwarded to
        ``tlc.Table.from_torch_dataset`` (e.g. ``if_exists``).
    :returns: The Table with the SegFormer preprocessing map and the
        metrics-collection map attached.
    """
    dataset = TLCSemanticSegmentationDataset(f"{dataset_location}/{split_dir}/", image_processor)
    return (
        tlc.Table.from_torch_dataset(
            dataset=dataset,
            structure=structure,
            table_name=table_name,
            dataset_name=dataset_name,
            project_name=PROJECT_NAME,
            **table_kwargs,
        )
        .map(dataset.transform_to_seg_former_format)
        .map_collect_metrics(dataset.transform_to_collect_metrics)
    )


tlc_train_dataset = create_tlc_table("train", "train_dataset", TRAIN_DATASET_NAME, if_exists="overwrite")
tlc_val_dataset = create_tlc_table("valid", "val_dataset", VAL_DATASET_NAME, if_exists="overwrite")
# NOTE: the test table intentionally keeps the library's default if_exists behavior.
tlc_test_dataset = create_tlc_table("test", "test_dataset", TEST_DATASET_NAME)

# Sampler derived from the 3LC table (e.g. honors sample weights).
sampler = tlc_train_dataset.create_sampler()

train_dataloader = DataLoader(tlc_train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, sampler=sampler)
val_dataloader = DataLoader(tlc_val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
test_dataloader = DataLoader(tlc_test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

segformer_finetuner = SegformerFinetuner(
    id2label,
    image_processor=image_processor,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    metrics_interval=10,
)
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b5 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Setup Callback to register predicted mask to 3LC#

[13]:
from tlc.client.torch.metrics.metrics_collectors.segmentation_metrics_collector import SegmentationMetricsCollector

# DataLoader settings used by tlc.collect_metrics when gathering per-sample metrics.
metrics_collection_dataloader_args = {
    "num_workers": NUM_WORKERS,
    "batch_size": BATCH_SIZE,
}


class MetricsCollectionCallBack(pl.Callback):
    """Lightning callback that collects per-sample 3LC metrics after each training epoch."""

    def __init__(self, dataset, post_process_function: Callable[[Any], object], id2label=None) -> None:
        """
        :param dataset: The 3LC Table to collect metrics for.
        :param post_process_function: Converts raw model outputs to predicted masks.
        :param id2label: Class-id-to-label mapping. Defaults to the module-level
            ``id2label`` when not given (preserving the original behavior, which
            read the global directly).
        """
        super().__init__()
        self.dataset = dataset
        self.post_process_function = post_process_function
        self.id2label = id2label

    def on_train_epoch_end(self, trainer, pl_module):
        # Move this logic to on_train_end instead if you only want to collect
        # metrics after the final epoch.
        segmentation_metrics_collector = SegmentationMetricsCollector(
            segmentation_model=pl_module.model,
            id2label=self.id2label if self.id2label is not None else id2label,
            post_process_function=self.post_process_function,
            current_epoch=pl_module.current_epoch,
        )
        pl_module.eval()  # collect metrics in inference mode
        tlc.collect_metrics(
            table=self.dataset,
            metrics_collectors=[segmentation_metrics_collector],
            constants={"epoch": pl_module.current_epoch},
            dataloader_args=metrics_collection_dataloader_args,
        )
        pl_module.train()  # restore training mode

Training time#

[14]:
# Stop training when the validation loss has not improved for 3 checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=False,
)

# Keep the single best checkpoint by validation loss, plus the last one.
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss", save_last=True)

collect_metrics_on_train = MetricsCollectionCallBack(
    dataset=tlc_train_dataset, post_process_function=image_processor.post_process_semantic_segmentation
)  # Setting up the callback on the train dataset to collect metrics

collect_metrics_on_test = MetricsCollectionCallBack(
    dataset=tlc_test_dataset, post_process_function=image_processor.post_process_semantic_segmentation
)

trainer = pl.Trainer(
    accelerator="gpu",
    callbacks=[early_stop_callback, checkpoint_callback, collect_metrics_on_train, collect_metrics_on_test],
    max_epochs=EPOCHS,
    # Validate once per full pass over the training dataloader.
    val_check_interval=len(train_dataloader),
    log_every_n_steps=7,
)

trainer.fit(segformer_finetuner)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/build/ado/w/2/pytorch-lightning-segformer_venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Missing logger folder: /home/build/ado/w/2/s/tlc-monorepo/public-notebooks/notebook_tests/pytorch-lightning-segformer/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                             | Params
-----------------------------------------------------------
0 | model | SegformerForSemanticSegmentation | 84.6 M
-----------------------------------------------------------
84.6 M    Trainable params
0         Non-trainable params
84.6 M    Total params
338.380   Total estimated model params size (MB)
/home/build/ado/w/2/pytorch-lightning-segformer_venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/build/ado/w/2/pytorch-lightning-segformer_venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.

Checking results#

[15]:
# Evaluate on the test set using the last saved checkpoint.
res = trainer.test(ckpt_path="last")
Restoring states from the checkpoint path at /home/build/ado/w/2/s/tlc-monorepo/public-notebooks/notebook_tests/pytorch-lightning-segformer/lightning_logs/version_0/checkpoints/last.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/build/ado/w/2/s/tlc-monorepo/public-notebooks/notebook_tests/pytorch-lightning-segformer/lightning_logs/version_0/checkpoints/last.ckpt
/home/build/ado/w/2/pytorch-lightning-segformer_venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss             0.08859418332576752    │
│    test_mean_accuracy         0.9040089845657349     │
│       test_mean_iou           0.8491796255111694     │
└───────────────────────────┴───────────────────────────┘
[16]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = next(iter(test_dataloader))
images, masks = batch["pixel_values"].to(DEVICE), batch["labels"].to(DEVICE)
segformer_finetuner.eval().to(DEVICE)
# return_dict=True so the processor's post-processing can read outputs.logits.
outputs = segformer_finetuner.model(images, masks, return_dict=True)
# Use the actual batch length rather than a hard-coded 8 so a final partial
# batch does not break the post-processing.
batch_prediction = image_processor.post_process_semantic_segmentation(outputs, [(640, 640)] * len(images))

n_plots = len(images)
fig, ax = plt.subplots(n_plots, 3)
fig.set_figheight(30)
fig.set_figwidth(30)
fig.subplots_adjust(wspace=-0.80)
for i in range(n_plots):
    # Ground-truth mask.
    ax[i, 0].imshow(masks[i, :, :].cpu().numpy(), cmap="gray")
    ax[i, 0].set_title("Mask id=" + str(i))

    # Prediction straight from the model.
    ax[i, 1].imshow(batch_prediction[i].cpu().numpy(), cmap="gray")
    ax[i, 1].set_title("Predicted mask from model id=" + str(i))

    # Prediction as stored in the latest 3LC metrics table.
    im = Image.open(run.metrics_tables[-1].table_rows[i]["predicted_mask"])
    ax[i, 2].imshow(np.array(im), cmap="gray")
    ax[i, 2].set_title("Predicted mask from 3lc id=" + str(i))

../_images/public-notebooks_pytorch-lightning-segformer_21_0.png