Fine-Tuning a SegFormer Model with PyTorch Lightning
This notebook is a modified version of Roboflow's official Colab tutorial, “How to Train SegFormer”.
In this tutorial we will see how to fine-tune a pre-trained SegFormer model for semantic segmentation on a custom dataset. We will integrate with 3LC by creating a training run, registering 3LC datasets, and collecting per-sample predicted masks.
This notebook demonstrates:
Training a SegFormer model on a custom dataset with PyTorch Lightning.
Registering the train/val/test sets as 3LC Tables.
Collecting per-sample predicted segmentation masks through callbacks.
Project Setup
[2]:
# --- 3LC project / run metadata ---
PROJECT_NAME = "PyTorch Lightning Image Segmentation"
RUN_NAME = "Train Balloon SegFormer"
DESCRIPTION = "Train a SegFormer model using PyTorch Lightning"

# --- Names for the train / validation / test datasets ---
TRAIN_DATASET_NAME = "balloons-train"
VAL_DATASET_NAME = "balloons-val"
TEST_DATASET_NAME = "balloons-test"

# --- Model and data location ---
MODEL = "nvidia/mit-b0"  # Hugging Face checkpoint passed to `from_pretrained`
TEST_DATA_PATH = "./data"  # the dataset is read from `<TEST_DATA_PATH>/balloons-mask-segmentation`

# --- Training hyperparameters ---
EPOCHS = 100
BATCH_SIZE = 8
NUM_WORKERS = 0

# --- Environment switches ---
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
DEVICE = None  # None -> auto-detect (CUDA, then MPS, then CPU)
INSTALL_DEPENDENCIES = False
Imports
[7]:
import csv
import os

import numpy as np
import pytorch_lightning as pl
import torch
from evaluate import load
from matplotlib import pyplot as plt
from PIL import Image
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

import tlc
[8]:
def _auto_select_device():
    """Return the preferred available device string: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


# An explicit DEVICE from the configuration cell always takes precedence over auto-detection.
device = torch.device(DEVICE if DEVICE is not None else _auto_select_device())
print(f"Using device: {device}")
Using device: cuda
Set Up Datasets and Training Helpers
We will create a 3LC Table with the images and their associated masks.
We also define helpers that pre-process this dataset into a form suitable for training and for collecting metrics.
Finally, we define a PyTorch Lightning `LightningModule` that implements the training, validation, and test steps.
[9]:
class TLCSemanticSegmentationDataset(VisionDataset):
    """Image (semantic) segmentation dataset read from a single flat directory.

    ``root_dir`` holds ``.jpg`` images and ``.png`` masks side by side. Images and masks are
    paired by position after sorting each list by filename, so the naming scheme must make the
    two sorted lists line up (NOTE(review): holds for the Roboflow export used in this notebook;
    confirm before reusing with other datasets).

    Each item is ``(image, mask, image.size, mask.size)``, where the sizes are PIL
    ``(width, height)`` tuples. If ``transforms`` is given, it is called with those four values
    and its result is returned instead.

    Raises:
        ValueError: if ``root_dir`` contains a different number of images and masks.
    """

    def __init__(self, root_dir, transforms=None):
        super().__init__(root_dir, transforms=transforms)
        self.root_dir = root_dir
        file_names = os.listdir(self.root_dir)
        # Match on the file extension, not on a substring anywhere in the name.
        self.images = sorted(name for name in file_names if name.endswith(".jpg"))
        self.masks = sorted(name for name in file_names if name.endswith(".png"))
        # Pairing is positional, so a count mismatch would silently misalign images and masks.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks in "
                f"{self.root_dir!r}; cannot pair them."
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.images[idx]))
        segmentation_map = Image.open(os.path.join(self.root_dir, self.masks[idx]))
        if self.transforms is not None:
            return self.transforms(image, segmentation_map, image.size, segmentation_map.size)
        return image, segmentation_map, image.size, segmentation_map.size
[10]:
def get_id2label(root_dir):
    """Read the class mapping from ``<root_dir>/_classes.csv``.

    The first row of the CSV is a header and is skipped. Every following non-blank row starts
    with ``<class id>, <class name>``. Quoted fields (e.g. a name containing a comma) are honoured,
    and whitespace around both columns is stripped.

    Returns:
        dict mapping the class id as a string (e.g. ``"0"``) to its class name.

    Raises:
        FileNotFoundError: if ``_classes.csv`` does not exist in ``root_dir``.
        IndexError: if a non-blank row has fewer than two columns.
    """
    classes_csv_file = os.path.join(root_dir, "_classes.csv")
    with open(classes_csv_file, newline="", encoding="utf-8") as fid:
        # skipinitialspace handles the "id, name" layout and lets quoting work after the space.
        rows = csv.reader(fid, skipinitialspace=True)
        next(rows, None)  # skip the header row
        return {
            row[0].strip(): row[1].strip()
            for row in rows
            if any(cell.strip() for cell in row)  # ignore blank / whitespace-only lines
        }
# SegFormer pre-processing: keep label ids unchanged (no label reduction) and target a 128x128 size.
image_processor = SegformerImageProcessor.from_pretrained(MODEL)
image_processor.do_reduce_labels = False
# `size` is documented as a {"height", "width"} dict; set it explicitly rather than relying on
# the processor coercing a bare int.
image_processor.size = {"height": 128, "width": 128}

dataset_location = tlc.Url(TEST_DATA_PATH + "/balloons-mask-segmentation").to_absolute()
id2label = get_id2label(f"{dataset_location}/train/")  # Assuming the same classes for train, val, and test

# Load the pre-trained checkpoint with a decode head sized for our classes;
# `ignore_mismatched_sizes=True` lets the head differ in shape from the checkpoint.
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL,
    num_labels=len(id2label),
    id2label=id2label,
    label2id={label: class_id for class_id, label in id2label.items()},
    ignore_mismatched_sizes=True,
)
/home/build/ado/w/3/pytorch-lightning-segformer_venv/lib/python3.9/site-packages/transformers/utils/deprecation.py:165: The following named arguments are not valid for `SegformerImageProcessor.__init__` and were ignored: 'feature_extractor_type'
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[11]:
def _size_columns(column_name):
    """Build a (width, height) integer pair column, matching the order of PIL's ``Image.size``."""
    return tlc.HorizontalTuple(column_name, [tlc.Int("width"), tlc.Int("height")])


# Schema of one dataset row as produced by TLCSemanticSegmentationDataset (before transforms):
# image, segmentation mask, and both original sizes.
structure = (
    tlc.PILImage("image"),
    tlc.SegmentationPILImage("mask", id2label),
    _size_columns("image size"),
    _size_columns("mask size"),
)
def mc_preprocess_fn(batch, predictor_output):
"""Transform a batch of inputs and model outputs to a format expected by the metrics collector."""
original_mask_size = batch["mask_size"].tolist()
outputs = predictor_output.forward
predicted_masks = image_processor.post_process_semantic_segmentation(
outputs=outputs,
target_sizes=original_mask_size,
)
return batch, predicted_masks
# Collector that turns model outputs into per-sample predicted masks via `mc_preprocess_fn`,
# with class names taken from `id2label`.
segmentation_metrics_collector = tlc.SegmentationMetricsCollector(
    preprocess_fn=mc_preprocess_fn,
    label_map=id2label,
)
[12]:
@tlc.lightning_module(
    structure=structure,
    dataset_prefix="balloons",
    run_name=RUN_NAME,
    run_description=DESCRIPTION,
    if_run_exists="overwrite",
    if_dataset_exists="overwrite",
    project_name=PROJECT_NAME,
    metrics_collectors=[segmentation_metrics_collector],
    metrics_collection_interval=10,
)
class SegformerFinetuner(pl.LightningModule):
    """LightningModule that fine-tunes a SegFormer model and tracks mean IoU per split.

    The ``tlc.lightning_module`` decorator wires the module into a 3LC run using ``structure``
    and ``segmentation_metrics_collector`` (the unit of ``metrics_collection_interval`` is defined
    by 3LC — NOTE(review): confirm whether it counts epochs or steps).

    Args:
        model: SegFormer model; when called with labels, ``outputs[0]`` is the loss and
            ``outputs[1]`` the logits.
        id2label: Mapping from class id to class name; only its length is used here.
        train_dataloader, val_dataloader, test_dataloader: Loaders yielding dicts with
            ``"pixel_values"`` and ``"labels"`` tensors.
        metrics_interval: Compute and log training mean IoU every this many batches.
    """

    # Label value excluded from the mean-IoU computation.
    IGNORE_INDEX = 255

    def __init__(
        self,
        model,
        id2label,
        train_dataloader=None,
        val_dataloader=None,
        test_dataloader=None,
        metrics_interval=100,
    ):
        super().__init__()
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.test_dl = test_dataloader
        self.metrics_interval = metrics_interval
        self.num_classes = len(id2label)
        self.model = model

        # One accumulator per split so predictions from different splits never mix.
        self.train_mean_iou = load("mean_iou")
        self.val_mean_iou = load("mean_iou")
        self.test_mean_iou = load("mean_iou")

        # Per-step outputs collected manually (>=2.0.0 fix: Lightning 2.x epoch-end hooks no
        # longer receive step outputs).
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, images, masks=None):
        """Run the wrapped model; ``masks`` are passed as labels so the outputs include a loss."""
        outputs = self.model(images, masks)
        return outputs

    def _shared_step(self, batch, mean_iou_metric):
        """Run the model on ``batch``, add its predictions to ``mean_iou_metric`` and return the loss."""
        images, masks = batch["pixel_values"], batch["labels"]
        outputs = self(images, masks)
        loss, logits = outputs[0], outputs[1]
        # Bring the logits to the label resolution so predictions and references align pixel-for-pixel.
        upsampled_logits = nn.functional.interpolate(
            logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
        )
        predicted = upsampled_logits.argmax(dim=1)
        mean_iou_metric.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        return loss

    def _compute_mean_iou(self, mean_iou_metric):
        """Compute mean IoU / accuracy from everything added to ``mean_iou_metric`` so far."""
        return mean_iou_metric.compute(
            num_labels=self.num_classes,
            ignore_index=self.IGNORE_INDEX,
            reduce_labels=False,
        )

    def training_step(self, batch, batch_idx):
        if batch_idx == 0:
            # A new epoch starts: drop the previous epoch's outputs so the list does not grow
            # for the entire run (nothing else ever clears it).
            self.training_step_outputs.clear()
        loss = self._shared_step(batch, self.train_mean_iou)
        if batch_idx % self.metrics_interval == 0:
            iou = self._compute_mean_iou(self.train_mean_iou)
            metrics = {
                "loss": loss,
                "mean_iou": iou["mean_iou"],
                "mean_accuracy": iou["mean_accuracy"],
            }
            for k, v in metrics.items():
                self.log(k, v, prog_bar=True)
            tlc.log(
                {
                    **{k: v.item() for k, v in metrics.items()},
                    "step": self.global_step,
                }
            )
        else:
            metrics = {"loss": loss}
        # Store detached copies: the returned `loss` still carries autograd history.
        self.training_step_outputs.append(
            {k: v.detach() if torch.is_tensor(v) else v for k, v in metrics.items()}
        )
        return metrics

    def validation_step(self, batch):
        loss = self._shared_step(batch, self.val_mean_iou)
        self.validation_step_outputs.append(loss)  # >=2.0.0 fix
        return {"val_loss": loss}

    def on_validation_epoch_end(self):
        iou = self._compute_mean_iou(self.val_mean_iou)
        avg_val_loss = torch.stack(self.validation_step_outputs).mean()  # >=2.0.0 fix
        metrics = {
            "val_loss": avg_val_loss,
            "val_mean_iou": iou["mean_iou"],
            "val_mean_accuracy": iou["mean_accuracy"],
        }
        for k, v in metrics.items():
            self.log(k, v, prog_bar=True)
        self.validation_step_outputs.clear()  # >=2.0.0 fix
        # Keep Lightning's pre-training sanity-check validation out of the 3LC run.
        if not self.trainer.sanity_checking:
            tlc.log(
                {
                    **{k: v.item() for k, v in metrics.items()},
                    "step": self.global_step,
                }
            )
        return metrics

    def test_step(self, batch):
        loss = self._shared_step(batch, self.test_mean_iou)
        self.test_step_outputs.append(loss)  # >=2.0.0 fix
        return {"test_loss": loss}

    def on_test_epoch_end(self):
        iou = self._compute_mean_iou(self.test_mean_iou)
        avg_test_loss = torch.stack(self.test_step_outputs).mean()  # >=2.0.0 fix
        metrics = {
            "test_loss": avg_test_loss,
            "test_mean_iou": iou["mean_iou"],
            "test_mean_accuracy": iou["mean_accuracy"],
        }
        for k, v in metrics.items():
            self.log(k, v)
        self.test_step_outputs.clear()  # >=2.0.0 fix
        return metrics

    def configure_optimizers(self):
        # Optimise only trainable parameters.
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)

    def train_dataloader(self):
        return self.train_dl

    def val_dataloader(self):
        return self.val_dl

    def test_dataloader(self):
        return self.test_dl
[13]:
def transforms(image, mask, image_size, mask_size):
encoded_inputs = image_processor(image, mask, return_tensors="pt")
for k, _ in encoded_inputs.items():
encoded_inputs[k].squeeze_() # remove batch dimension
# Add the original mask size to the batch so that we can resize the mask back to its original size later
encoded_inputs.update({"mask_size": torch.tensor(mask_size)})
return encoded_inputs
# One dataset per Roboflow split folder; all of them share the same `transforms`.
train_dataset, val_dataset, test_dataset = (
    TLCSemanticSegmentationDataset(f"{dataset_location}/{split}/", transforms)
    for split in ("train", "valid", "test")
)

# Only the training loader shuffles; validation and test keep a fixed order.
loader_options = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

segformer_finetuner = SegformerFinetuner(
    model,
    id2label,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
)
Training
[14]:
# Stop once val_loss has not improved for 5 consecutive validation runs; keep the best
# (lowest val_loss) checkpoint and also the last one.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=5,
    verbose=True,
)
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss", save_last=True)
trainer = pl.Trainer(
    # Follow the device chosen in the device cell instead of hard-coding "gpu", so the notebook
    # also runs on MPS- or CPU-only machines ("cuda", "mps" and "cpu" are valid accelerators).
    accelerator=device.type,
    # Pin to the selected device index (e.g. "cuda:0"); otherwise let Lightning decide.
    devices=[device.index] if device.index is not None else "auto",
    callbacks=[early_stop_callback, checkpoint_callback],
    max_epochs=EPOCHS,
    val_check_interval=len(train_dataloader),  # validate once per full pass over the training data
    log_every_n_steps=7,
)
trainer.fit(segformer_finetuner)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/build/ado/w/3/pytorch-lightning-segformer_venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
------------------------------------------------------------------
0 | model | SegformerForSemanticSegmentation | 3.7 M | eval
------------------------------------------------------------------
3.7 M Trainable params
0 Non-trainable params
3.7 M Total params
14.859 Total estimated model params size (MB)
0 Modules in train mode
213 Modules in eval mode
/home/build/ado/w/3/pytorch-lightning-segformer_venv/lib/python3.9/site-packages/datasets/features/image.py:348: Downcasting array dtype int64 to int32 to be compatible with 'Pillow'
Metric val_loss improved. New best score: 0.653
Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 0.637
Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.626
Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.618
Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 0.599
Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.585
Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.582
Metric val_loss improved by 0.048 >= min_delta = 0.0. New best score: 0.534
Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.530
Metric val_loss improved by 0.042 >= min_delta = 0.0. New best score: 0.488
Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.483
Metric val_loss improved by 0.037 >= min_delta = 0.0. New best score: 0.447
Metric val_loss improved by 0.022 >= min_delta = 0.0. New best score: 0.424
Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 0.404
Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 0.375
Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.366
Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 0.348
Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.339
Metric val_loss improved by 0.025 >= min_delta = 0.0. New best score: 0.314
Metric val_loss improved by 0.031 >= min_delta = 0.0. New best score: 0.283
Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.277
Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.271
Metric val_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.262
Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.259
Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.256
Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.247
Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.239
Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.234
Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.231
Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.230
Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.228
Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.227
Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.218
Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.215
Monitored metric val_loss did not improve in the last 5 records. Best score: 0.215. Signaling Trainer to stop.
Checking Results
[15]:
# Evaluate the final-epoch weights (the "last" checkpoint, written because save_last=True) on the test split.
res = trainer.test(
    ckpt_path="last",
)
Restoring states from the checkpoint path at /home/build/ado/w/3/s/tlc-monorepo/public-notebooks/notebook_tests/pytorch-lightning-segformer/lightning_logs/version_0/checkpoints/last.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/build/ado/w/3/s/tlc-monorepo/public-notebooks/notebook_tests/pytorch-lightning-segformer/lightning_logs/version_0/checkpoints/last.ckpt
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Test metric ┃ DataLoader 0 ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ test_loss │ 0.17538779973983765 │ │ test_mean_accuracy │ 0.9275992512702942 │ │ test_mean_iou │ 0.847816526889801 │ └───────────────────────────┴───────────────────────────┘
[16]:
%matplotlib inline
# Visual check on the first test batch: ground truth, predictions from the in-memory model, and
# the predicted masks that 3LC stored in the run's most recent metrics table.
batch = next(iter(test_dataloader))
images, masks = batch["pixel_values"].to(device), batch["labels"].to(device)
segformer_finetuner.eval().to(device)
with torch.no_grad():  # inference only: do not build an autograd graph
    outputs = segformer_finetuner.model(images, masks, return_dict=True)
# `mask_size` holds PIL (width, height) pairs; post-processing expects (height, width) target sizes.
target_sizes = [(height, width) for width, height in batch["mask_size"].tolist()]
batch_prediction = image_processor.post_process_semantic_segmentation(outputs, target_sizes)

n_rows = len(images)
n_cols = 3
fig_width = n_cols * 5
fig_height = n_rows * 5
# squeeze=False keeps `ax` two-dimensional even when the batch holds a single image.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), squeeze=False)
fig.suptitle("Test Batch Predictions", fontsize=16)

# NOTE(review): assumes the latest metrics table was produced on the test split and that its rows
# follow the (unshuffled) test DataLoader order, so row i matches batch item i — confirm in 3LC.
latest_metrics_table = tlc.active_run().metrics_tables[-1]
for i in range(n_rows):
    for j in range(n_cols):
        ax[i, j].axis("off")
    ax[i, 0].imshow(masks[i, :, :].cpu().numpy(), cmap="gray")
    ax[i, 0].set_title(f"Ground Truth (id={i})", fontsize=14)
    ax[i, 1].imshow(batch_prediction[i].cpu().numpy(), cmap="gray")
    ax[i, 1].set_title("Predicted mask (latest model)", fontsize=14)
    im = latest_metrics_table[i]["predicted_mask"]
    ax[i, 2].imshow(np.array(im), cmap="gray")
    ax[i, 2].set_title("Predicted mask (3LC)", fontsize=14)

# Lay out after all images and titles exist so tight_layout can account for them.
plt.tight_layout(pad=3.0, h_pad=-1.0, w_pad=1.0, rect=[0, 0, 1, 1])
plt.show()