View source Download .ipynb

Train a YOLO-NAS model for pose estimation with SuperGradients¶

This tutorial trains a SuperGradients YOLO-NAS model for pose estimation on the AnimalPose dataset.

The input Table required for running this notebook is created in create-custom-keypoints-table.ipynb.

image1

This notebook is a modified version of the SuperGradients YoloNAS Pose Fine Tuning Notebook.

Install dependencies¶

[ ]:
%pip install 3lc
%pip install super-gradients
%pip install termcolor==3.1.0
%pip install git+https://github.com/3lc-ai/3lc-examples.git

Project setup¶

[ ]:
# 3LC identifiers: which Table to read and the Run name used for this training.
PROJECT_NAME = "3LC Tutorials - 2D Keypoints"
DATASET_NAME = "AnimalPose"
TABLE_NAME = "initial"
RUN_NAME = "fine-tune-yolo-nas-pose-n-animalpose"

# Model selection and training hyperparameters.
MODEL_NAME = "yolo_nas_pose_n"
IMAGE_SIZE = 640
MAX_EPOCHS = 10
BATCH_SIZE = 16
NUM_WORKERS = 0  # passed to the DataLoaders; 0 loads data in the main process

# Local directory for the downloaded checkpoint and SuperGradients checkpoints.
DOWNLOAD_PATH = "../../transient_data"

Imports¶

[ ]:
import requests
from super_gradients.training import Trainer, models
from super_gradients.training.datasets.pose_estimation_datasets import YoloNASPoseCollateFN
from super_gradients.training.metrics import PoseEstimationMetrics
from super_gradients.training.models.pose_estimation_models.yolo_nas_pose import YoloNASPosePostPredictionCallback
from super_gradients.training.transforms.keypoints import (
    KeypointsBrightnessContrast,
    KeypointsHSV,
    KeypointsImageStandardize,
    KeypointsLongestMaxSize,
    KeypointsPadIfNeeded,
    KeypointsRandomAffineTransform,
    KeypointsRandomHorizontalFlip,
    KeypointsRemoveSmallObjects,
)
from super_gradients.training.utils.callbacks import Callback
from tlc.core import KeypointHelper, Table
from tlc.integration.super_gradients import PoseEstimationDataset, PoseEstimationMetricsCollectionCallback
from torch.utils.data import DataLoader
from torchmetrics.metric import Metric

from tlc_tools.split import split_table

Download pretrained model¶

[ ]:
from pathlib import Path

# COCO-pose pretrained YOLO-NAS-Pose-N weights used as the fine-tuning starting point.
MODEL_URL = "https://sg-hub-nv.s3.amazonaws.com/models/yolo_nas_pose_n_coco_pose.pth"
MODEL_PATH = Path(DOWNLOAD_PATH) / "yolo_nas_pose_n_coco_pose.pth"

# Download once; later runs reuse the cached file.
if not MODEL_PATH.exists():
    MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Stream into a temporary file and move it into place only after the whole
    # download succeeded. An HTTP error or interrupted transfer must never leave a
    # bogus file at MODEL_PATH, because the exists() check above would then skip
    # re-downloading forever.
    partial_path = MODEL_PATH.with_name(MODEL_PATH.name + ".part")
    with requests.get(MODEL_URL, stream=True, timeout=60) as response:
        response.raise_for_status()
        with partial_path.open("wb") as fh:
            for chunk in response.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    partial_path.replace(MODEL_PATH)

Load and split input tables¶

[ ]:
# Load the full keypoints Table (created in create-custom-keypoints-table.ipynb);
# it is split into train/val/test below.
initial_table = Table.from_names(TABLE_NAME, DATASET_NAME, PROJECT_NAME)


def split_by(table_row):
    """Return the label of the first keypoint instance in ``table_row``.

    Used as the stratification key passed to ``split_table``, so the split is
    stratified by label just like in the original SuperGradients notebook.
    """
    instance_data = table_row["keypoints_2d"]["instances_additional_data"]
    labels = instance_data["label"]
    return labels[0]


# Both splits use identical stratified-split settings.
_stratified_split_kwargs = dict(
    split_strategy="stratified",
    split_by=split_by,
    shuffle=False,
    random_seed=42,
)

# First split: 80% train, 20% held out for validation + testing.
train_val_test = split_table(
    initial_table,
    splits={"train": 0.8, "val_test": 0.2},
    **_stratified_split_kwargs,
)

# Second split: divide the held-out part evenly into val and test (roughly 10% of the total each).
test_val = split_table(
    train_val_test["val_test"],
    splits={"val": 0.5, "test": 0.5},
    **_stratified_split_kwargs,
)
[ ]:
# Unpack the three final splits produced by the two split_table calls.
train_table, val_table, test_table = (
    train_val_test["train"],
    test_val["val"],
    test_val["test"],
)
[ ]:
# Show the source Table and the three derived splits, one per line.
print(initial_table, train_table, val_table, test_table, sep="\n")

Prepare for training¶

[ ]:
def create_transforms(image_size: int, flip_indices: list[int]):
    """Build the keypoint transform pipelines for training and validation.

    Args:
        image_size: Passed as the max height/width to ``KeypointsLongestMaxSize`` and as
            the min height/width to ``KeypointsPadIfNeeded``.
        flip_indices: Forwarded as ``flip_index`` to ``KeypointsRandomHorizontalFlip``
            (presumably the left/right keypoint swap mapping; confirm with KeypointHelper).

    Returns:
        A ``(train_transforms, val_transforms)`` tuple of lists. The resize, pad and
        standardize transform objects are the same instances in both lists.
    """
    # Random augmentations: used by the training pipeline only.
    flip = KeypointsRandomHorizontalFlip(flip_index=flip_indices, prob=0.5)
    hsv = KeypointsHSV(prob=0.5, hgain=20, sgain=20, vgain=20)
    brightness_contrast = KeypointsBrightnessContrast(
        prob=0.5,
        brightness_range=[0.8, 1.2],
        contrast_range=[0.8, 1.2],
    )
    affine = KeypointsRandomAffineTransform(
        max_rotation=0,
        min_scale=0.5,
        max_scale=1.5,
        max_translate=0.1,
        image_pad_value=127,
        mask_pad_value=1,
        prob=0.75,
        interpolation_mode=[0, 1, 2, 3, 4],
    )

    # Resize / pad / standardize: shared tail of both pipelines.
    resize_and_normalize = [
        KeypointsLongestMaxSize(max_height=image_size, max_width=image_size),
        KeypointsPadIfNeeded(
            min_height=image_size,
            min_width=image_size,
            image_pad_value=[127, 127, 127],
            mask_pad_value=1,
            padding_mode="bottom_right",
        ),
        KeypointsImageStandardize(max_value=255),
    ]

    # Final training-only step.
    remove_small_objects = KeypointsRemoveSmallObjects(min_instance_area=1, min_visible_keypoints=1)

    train_pipeline = [flip, hsv, brightness_contrast, affine, *resize_and_normalize, remove_small_objects]
    val_pipeline = resize_and_normalize
    return train_pipeline, val_pipeline


def create_training_params(max_epochs: int, callbacks: list[Callback], metrics: list[Metric], oks_sigmas: list[float]):
    """Assemble the SuperGradients ``training_params`` dict for YOLO-NAS-Pose fine-tuning.

    Args:
        max_epochs: Number of training epochs (``max_epochs``).
        callbacks: Trainer phase callbacks (``phase_callbacks``).
        metrics: Validation metrics (``valid_metrics_list``).
        oks_sigmas: Per-keypoint OKS sigmas forwarded to the ``yolo_nas_pose_loss`` criterion.

    Returns:
        A new dict on every call, intended for ``Trainer.train(training_params=...)``.
    """
    loss_params = dict(
        oks_sigmas=oks_sigmas,
        classification_loss_weight=1.0,
        classification_loss_type="focal",
        regression_iou_loss_type="ciou",
        iou_loss_weight=2.5,
        dfl_loss_weight=0.01,
        pose_cls_loss_weight=1.0,
        pose_reg_loss_weight=34.0,
        pose_classification_loss_type="focal",
        rescale_pose_loss_with_assigned_score=True,
        assigner_multiply_by_pose_oks=True,
    )

    return dict(
        seed=42,
        # Learning-rate schedule: per-batch linear warmup, then cosine decay.
        warmup_mode="LinearBatchLRWarmup",
        warmup_initial_lr=1e-8,
        lr_warmup_epochs=2,
        initial_lr=5e-4,
        lr_mode="cosine",
        cosine_final_lr_ratio=0.05,
        max_epochs=max_epochs,
        zero_weight_decay_on_bias_and_bn=True,
        batch_accumulate=1,
        # Checkpointing.
        average_best_models=False,
        save_ckpt_epoch_list=[],
        # Loss.
        loss="yolo_nas_pose_loss",
        criterion_params=loss_params,
        # Optimization.
        optimizer="AdamW",
        optimizer_params={"weight_decay": 1e-6},
        ema=True,
        ema_params={"decay": 0.997, "decay_type": "threshold"},
        mixed_precision=True,
        sync_bn=False,
        # Validation and callbacks; the best model is selected by the "AP" metric (higher is better).
        valid_metrics_list=metrics,
        phase_callbacks=callbacks,
        pre_prediction_callback=None,
        metric_to_watch="AP",
        greater_metric_to_watch_is_better=True,
    )
[ ]:
# Keypoint metadata is read from the full (unsplit) Table so every split shares it.
flip_indices = KeypointHelper.get_flip_indices_from_table(initial_table)
oks_sigmas = KeypointHelper.get_oks_sigmas_from_table(initial_table)

train_transforms, val_transforms = create_transforms(IMAGE_SIZE, flip_indices)
[ ]:
# Wrap each 3LC Table split in a SuperGradients-compatible pose dataset.
train_dataset, val_dataset = (
    PoseEstimationDataset(train_table, transforms=train_transforms),
    PoseEstimationDataset(val_table, transforms=val_transforms),
)
[ ]:
# Cap on poses kept per image after NMS; the metric uses the same cap so the two agree.
MAX_OBJECTS_PER_IMAGE = 15

post_prediction_callback = YoloNASPosePostPredictionCallback(
    pose_confidence_threshold=0.01,
    nms_iou_threshold=0.7,
    pre_nms_max_predictions=100,
    post_nms_max_predictions=MAX_OBJECTS_PER_IMAGE,
)

pose_estimation_metrics = PoseEstimationMetrics(
    num_joints=train_dataset.num_joints,
    oks_sigmas=oks_sigmas,
    max_objects_per_image=MAX_OBJECTS_PER_IMAGE,
    post_prediction_callback=post_prediction_callback,
)

# 3LC callback, registered under this project/run name.
tlc_callback = PoseEstimationMetricsCollectionCallback(run_name=RUN_NAME, project_name=PROJECT_NAME)
[ ]:
# Plug the validation metric, the 3LC callback and the dataset's OKS sigmas into the training config.
training_params = create_training_params(
    oks_sigmas=oks_sigmas,
    metrics=[pose_estimation_metrics],
    callbacks=[tlc_callback],
    max_epochs=MAX_EPOCHS,
)
[ ]:
# Training drops the final incomplete batch. Validation must not: with drop_last=True
# the last len(val_dataset) % BATCH_SIZE samples would be silently excluded from
# every validation pass.
train_dataloader_params = {
    "shuffle": True,
    "batch_size": BATCH_SIZE,
    "drop_last": True,
    "pin_memory": False,
    "collate_fn": YoloNASPoseCollateFN(),
    "num_workers": NUM_WORKERS,
    # persistent_workers is only valid when worker processes are used.
    "persistent_workers": NUM_WORKERS > 0,
}
val_dataloader_params = {
    "shuffle": False,
    "batch_size": BATCH_SIZE,
    # Evaluate every validation sample, including a final partial batch.
    "drop_last": False,
    "pin_memory": False,
    "collate_fn": YoloNASPoseCollateFN(),
    "num_workers": NUM_WORKERS,
    "persistent_workers": NUM_WORKERS > 0,
}

train_dataloader = DataLoader(train_dataset, **train_dataloader_params)
val_dataloader = DataLoader(val_dataset, **val_dataloader_params)

Train model¶

[ ]:
from pathlib import Path

# Fine-tune from the COCO-pose checkpoint (17 keypoints). For YOLO-NAS-Pose,
# num_classes is the number of keypoints. Take it from the dataset, as the metric
# above does, instead of hard-coding it, so the head always matches the Table's schema.
yolo_nas_pose = models.get(
    MODEL_NAME,
    num_classes=train_dataset.num_joints,
    checkpoint_path=MODEL_PATH.as_posix(),
    checkpoint_num_classes=17,
).cuda()

# SuperGradients checkpoints go to <DOWNLOAD_PATH>/sg-checkpoints.
trainer = Trainer(
    experiment_name=RUN_NAME,
    ckpt_root_dir=(Path(DOWNLOAD_PATH) / "sg-checkpoints").as_posix(),
)
[ ]:
# Launch fine-tuning with the training config and the train/val dataloaders built above.
trainer.train(
    model=yolo_nas_pose,
    training_params=training_params,
    train_loader=train_dataloader,
    valid_loader=val_dataloader,
)