Balloons Toy Dataset + Detectron2 + 3LC Tutorial#

This notebook is a modified version of the official detectron2 Colab tutorial.

In this tutorial we fine-tune a pre-trained detectron2 model for object detection on a custom dataset in the COCO format. We integrate with 3LC by creating a training run, registering 3LC datasets, and collecting per-sample bounding box metrics.

This notebook demonstrates:

  • Training a detectron2 model on a custom dataset.

  • Integrating a COCO dataset with 3LC using register_coco_instances().

  • Collecting per-sample bounding box metrics using BoundingBoxMetricsCollector.

  • Registering a custom per-sample metrics collection callback.

Setup Project#

[2]:
PROJECT_NAME = "Balloons"
RUN_NAME = "Train Balloon Detector"
DESCRIPTION = "Train a balloon detector using detectron2"
TRAIN_DATASET_NAME = "balloons-train"
VAL_DATASET_NAME = "balloons-val"
TRANSIENT_DATA_PATH = "../transient_data"
TEST_DATA_PATH = "./data"
MODEL_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
MAX_ITERS = 200
BATCH_SIZE = 2
MAX_DETECTIONS_PER_IMAGE = 30
SCORE_THRESH_TEST = 0.5
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
EPOCHS = 1
INSTALL_DEPENDENCIES = False
[4]:
%%capture
if INSTALL_DEPENDENCIES:
    # NOTE: There is no single version of detectron2 that is appropriate for all users and all systems.
    #       This notebook uses a particular prebuilt version of detectron2 that is only available for
    #       Linux and for specific versions of torch, torchvision, and CUDA. It may not be appropriate
    #       for your system. See https://detectron2.readthedocs.io/en/latest/tutorials/install.html for
    #       instructions on how to install or build a version of detectron2 for your system.
    %pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
    %pip install detectron2 -f "https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.10/index.html"
    %pip install 3lc
    %pip install opencv-python
    %pip install matplotlib

Imports#

[9]:
import detectron2
import tlc
import torch

TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("tlc: ", tlc.__version__)
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)
tlc:  99.0.0
torch:  1.10 ; cuda:  cu111
detectron2: 0.6
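
The version strings printed above combine a release number with an optional local build suffix such as `cu111`. As a small sketch of the same parsing in isolation, the hypothetical helper `split_torch_version` below (not part of torch or 3LC) also handles builds without a `+` suffix, for which `split("+")[-1]` would return the whole version string.

```python
def split_torch_version(version):
    """Split a version such as '1.10.1+cu111' into ('1.10', 'cu111').

    Hypothetical helper for illustration only. The second element is None
    when the version string carries no '+<suffix>' part.
    """
    base, _, local = version.partition("+")
    major_minor = ".".join(base.split(".")[:2])
    return major_minor, (local or None)


print(split_torch_version("1.10.1+cu111"))  # ('1.10', 'cu111')
print(split_torch_version("2.1.0"))         # ('2.1', None)
```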
[10]:
# Some basic setup:
from __future__ import annotations

# Setup detectron2 logger
from detectron2.utils.logger import setup_logger

logger = setup_logger()
logger.setLevel("ERROR")

# import some common libraries
import os
import cv2
import random
import matplotlib.pyplot as plt

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

Prepare the dataset#

In this section, we show how to train an existing detectron2 model on a custom dataset in the COCO format.

We use the balloon segmentation dataset which only has one class: balloon.

A modified COCO version of this dataset is included in the “data” directory of our repository.

We’ll train a balloon segmentation model starting from an existing model pre-trained on the COCO dataset, available in detectron2’s model zoo.

Note that the COCO dataset does not have a “balloon” category. After a few minutes of training, the model will be able to recognize this new class.
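
For readers unfamiliar with the COCO format, the sketch below builds a tiny, hypothetical annotation structure and runs a basic consistency check on it. The file name, ids, and box values are made up for illustration, segmentation polygons are omitted for brevity, and `check_coco` is our own helper rather than part of detectron2 or 3LC.

```python
# Hypothetical, minimal COCO-style annotation data: one image, one balloon.
coco = {
    "images": [{"id": 1, "file_name": "example.jpg", "width": 640, "height": 480}],
    "categories": [{"id": 0, "name": "balloon"}],
    "annotations": [
        {
            "id": 1,
            "image_id": 1,
            "category_id": 0,
            "bbox": [100.0, 120.0, 50.0, 80.0],  # [x, y, width, height] in pixels
            "area": 50.0 * 80.0,
            "iscrowd": 0,
        }
    ],
}


def check_coco(data):
    """Return a list of problems; an empty list means the basic structure is consistent."""
    problems = []
    for key in ("images", "annotations", "categories"):
        if key not in data:
            problems.append(f"missing top-level key: {key}")
    image_ids = {img["id"] for img in data.get("images", [])}
    category_ids = {cat["id"] for cat in data.get("categories", [])}
    for ann in data.get("annotations", []):
        if ann["image_id"] not in image_ids:
            problems.append(f"annotation refers to unknown image id {ann['image_id']}")
        if ann["category_id"] not in category_ids:
            problems.append(f"annotation refers to unknown category id {ann['category_id']}")
        if len(ann["bbox"]) != 4:
            problems.append("bbox must have four values: [x, y, width, height]")
    return problems


print(check_coco(coco))  # []
```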

Register the dataset with 3LC#

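Now that we have the dataset in the COCO format, we can register it with 3LC. The variables train_json_path, train_image_folder, val_json_path, and val_image_folder used below are expected to be tlc.Url objects pointing to the COCO annotation files and image folders of the two splits.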

[12]:
from tlc.integration.detectron2 import register_coco_instances

register_coco_instances(
    TRAIN_DATASET_NAME,
    {},
    train_json_path.to_str(),
    train_image_folder.to_str(),
    project_name=PROJECT_NAME,
)

register_coco_instances(
    VAL_DATASET_NAME,
    {},
    val_json_path.to_str(),
    val_image_folder.to_str(),
    project_name=PROJECT_NAME,
)
[13]:
# The detectron2 dataset dicts and dataset metadata can be read from the DatasetCatalog and
# MetadataCatalog, respectively.
dataset_metadata = MetadataCatalog.get(TRAIN_DATASET_NAME)
dataset_dicts = DatasetCatalog.get(TRAIN_DATASET_NAME)

To verify that the dataset is in the correct format, let’s visualize the annotations of randomly selected samples in the training set:

[14]:
import numpy as np
from detectron2.utils.file_io import PathManager

for d in random.sample(dataset_dicts, 3):
    filename = tlc.Url(d["file_name"]).to_absolute().to_str()
    if "s3://" in filename:
        with PathManager.open(filename, "rb") as f:
            img = np.asarray(bytearray(f.read()), dtype="uint8")
            img = cv2.imdecode(img, cv2.IMREAD_COLOR)
    else:
        img = cv2.imread(filename)
    visualizer = Visualizer(img[:, :, ::-1], metadata=dataset_metadata, scale=0.5)
    out = visualizer.draw_dataset_dict(d)
    # The Visualizer was given an RGB image, so get_image() already returns RGB, which matplotlib expects
    plt.imshow(out.get_image())
    plt.title(filename.split("/")[-1])
    plt.show()
[Output: three randomly selected training images with their annotations drawn on top]

Create a custom metrics collection function#

We will use a BoundingBoxMetricsCollector to collect per-sample bounding box metrics. It accepts a user-supplied function that computes additional custom metrics alongside the built-in ones.

[15]:
def custom_bbox_metrics_collector(
    gts: list[tlc.COCOGroundTruth], preds: list[tlc.COCOPrediction], metrics: dict[str, list] | None = None
) -> None:
    """Example function that computes custom metrics for bounding box detection."""

    # Record the number of ground truth boxes and predictions per sample; results are written into `metrics` in place
    num_gts = [len(gt["annotations"]) for gt in gts]
    num_preds = [len(pred["annotations"]) for pred in preds]

    metrics["num_gts"] = num_gts
    metrics["num_preds"] = num_preds
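
To see what such a function writes, it can be exercised on hand-made inputs. The sketch below repeats the same counting logic without the tlc type hints, so it runs without 3LC installed. The dictionaries are hypothetical stand-ins that contain only the "annotations" key the function reads, and `count_boxes` is an illustrative name.

```python
def count_boxes(gts, preds, metrics):
    """Same counting logic as custom_bbox_metrics_collector, written into `metrics` in place."""
    metrics["num_gts"] = [len(gt["annotations"]) for gt in gts]
    metrics["num_preds"] = [len(pred["annotations"]) for pred in preds]


# Hypothetical batch of two samples: the first has two ground truth boxes, the second has none.
gts = [
    {"annotations": [{"bbox": [10, 10, 20, 30]}, {"bbox": [50, 40, 15, 25]}]},
    {"annotations": []},
]
# One predicted box per sample.
preds = [
    {"annotations": [{"bbox": [12, 11, 19, 28], "score": 0.9}]},
    {"annotations": [{"bbox": [5, 5, 10, 10], "score": 0.6}]},
]

metrics = {}
count_boxes(gts, preds, metrics)
print(metrics)  # {'num_gts': [2, 0], 'num_preds': [1, 1]}
```

Each metric is a list with one entry per sample, which is what makes it a per-sample metric.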

Train!#

Now, let’s fine-tune a COCO-pretrained R50-FPN Mask R-CNN model on the balloon dataset. Training for 300 iterations takes ~2 minutes on a P100 GPU; this notebook trains for MAX_ITERS = 200 iterations.

[16]:
import tlc

run = tlc.init(
    project_name=PROJECT_NAME,
    run_name=RUN_NAME,
    description=DESCRIPTION,
    if_exists="overwrite",
)
[17]:
# For a full list of config values: https://github.com/facebookresearch/detectron2/blob/main/detectron2/config/defaults.py
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(MODEL_CONFIG))
cfg.DATASETS.TRAIN = (TRAIN_DATASET_NAME,)
cfg.DATASETS.TEST = (VAL_DATASET_NAME,)
cfg.DATALOADER.NUM_WORKERS = 0
cfg.OUTPUT_DIR = TRANSIENT_DATA_PATH
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_CONFIG)  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = BATCH_SIZE  # This is the real "batch size" commonly known to deep learning people
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = (
    MAX_ITERS  # Seems good enough for this toy dataset; you will need to train longer for a practical dataset
)
cfg.SOLVER.STEPS = []  # Do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = (
    128  # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # Only has one class (balloon).

cfg.TEST.DETECTIONS_PER_IMAGE = MAX_DETECTIONS_PER_IMAGE
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESH_TEST
cfg.MODEL.DEVICE = "cuda"
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

config = {
    "model_config": MODEL_CONFIG,
    "solver.ims_per_batch": BATCH_SIZE,
    "test.detections_per_image": MAX_DETECTIONS_PER_IMAGE,
    "model.roi_heads.score_thresh_test": SCORE_THRESH_TEST,
}

run.set_parameters(config)
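
Note that cfg.SOLVER.MAX_ITER counts iterations rather than epochs, and each iteration processes cfg.SOLVER.IMS_PER_BATCH images. The sketch below works through that arithmetic. `NUM_TRAIN_IMAGES` is a hypothetical placeholder, not the real dataset size; in the notebook, len(dataset_dicts) would give the actual number.

```python
MAX_ITERS = 200         # cfg.SOLVER.MAX_ITER in this notebook
BATCH_SIZE = 2          # cfg.SOLVER.IMS_PER_BATCH in this notebook
NUM_TRAIN_IMAGES = 50   # HYPOTHETICAL placeholder; use len(dataset_dicts) for the real value


def approx_epochs(max_iters, batch_size, num_images):
    """Approximate number of passes over the training set."""
    return max_iters * batch_size / num_images


images_seen = MAX_ITERS * BATCH_SIZE
print(images_seen)                                             # 400
print(approx_epochs(MAX_ITERS, BATCH_SIZE, NUM_TRAIN_IMAGES))  # 8.0
```
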
[18]:
from detectron2.engine import DefaultTrainer
from tlc.integration.detectron2 import DetectronMetricsCollectionHook, MetricsCollectionHook

trainer = DefaultTrainer(cfg)

metrics_collector = tlc.BoundingBoxMetricsCollector(
    classes=dataset_metadata.thing_classes,
    label_mapping=dataset_metadata.thing_dataset_id_to_contiguous_id,
    iou_threshold=0.5,
    compute_derived_metrics=True,
    extra_metrics_fn=custom_bbox_metrics_collector,
)

# Add schemas for the custom metrics defined above
metrics_collector.add_schema(
    "num_gts",
    tlc.Schema(value=tlc.Int32Value(value_min=0), description="The number of ground truth boxes"),
)
metrics_collector.add_schema(
    "num_preds",
    tlc.Schema(value=tlc.Int32Value(value_min=0), description="The number of predicted boxes"),
)

# Register the metrics collection hooks with the trainer:
# + Collect metrics on the training set every 50 iterations starting at iteration 0, and again after training
# + Collect metrics on the validation set after training
# + Collect default detectron2 metrics every 5 iterations
trainer.register_hooks(
    [
        MetricsCollectionHook(
            dataset_name=TRAIN_DATASET_NAME,
            metrics_collectors=[metrics_collector],
            collection_frequency=50,
            collection_start_iteration=0,
            collect_metrics_after_train=True,
        ),
        MetricsCollectionHook(
            dataset_name=VAL_DATASET_NAME,
            metrics_collectors=[metrics_collector],
            collect_metrics_after_train=True,
        ),
        DetectronMetricsCollectionHook(
            run_url=run.url,
            collection_frequency=5,
        ),
    ]
)
trainer.resume_or_load(resume=False)
trainer.train()
model_final_f10217.pkl: 178MB [00:00, 213MB/s]
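
The metrics collector above is configured with iou_threshold=0.5. As a rough illustration of the quantity that threshold is compared against, here is a sketch of intersection over union (IoU) for two boxes in the COCO [x, y, width, height] layout. It assumes the common convention that a prediction counts as matching a ground truth box when their IoU reaches the threshold; how 3LC matches boxes internally is not shown here, and `iou_xywh` is our own helper.

```python
def iou_xywh(box_a, box_b):
    """Intersection over union of two [x, y, width, height] boxes."""
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b
    # Width and height of the overlap rectangle (zero if the boxes do not overlap)
    inter_w = max(0.0, min(ax + aw, bx + bw) - max(ax, bx))
    inter_h = max(0.0, min(ay + ah, by + bh) - max(ay, by))
    inter = inter_w * inter_h
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0


IOU_THRESHOLD = 0.5  # same value as iou_threshold in the notebook

gt_box = [0, 0, 10, 10]
pred_box = [5, 0, 10, 10]  # shifted right by half the box width

iou = iou_xywh(gt_box, pred_box)
print(round(iou, 3))         # 0.333
print(iou >= IOU_THRESHOLD)  # False
```

Two equally sized boxes that overlap by half their width reach an IoU of only 1/3, so a threshold of 0.5 demands a fairly tight fit.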