Train a YOLO-NAS model for pose estimation with SuperGradients¶
This tutorial trains a SuperGradients YOLO-NAS model for pose estimation on the AnimalPose dataset.
The input Table required for running this notebook is created in create-custom-keypoints-table.ipynb.

This notebook is a modified version of the SuperGradients YoloNAS Pose Fine Tuning Notebook.
Install dependencies¶
[ ]:
%pip install 3lc
%pip install super-gradients
%pip install termcolor==3.1.0
%pip install git+https://github.com/3lc-ai/3lc-examples.git
Project setup¶
[ ]:
PROJECT_NAME = "3LC Tutorials - 2D Keypoints"
DATASET_NAME = "AnimalPose"
TABLE_NAME = "initial"
MODEL_NAME = "yolo_nas_pose_n"
RUN_NAME = "fine-tune-yolo-nas-pose-n-animalpose"
BATCH_SIZE = 16
NUM_WORKERS = 0
DOWNLOAD_PATH = "../../transient_data"
MAX_EPOCHS = 10
IMAGE_SIZE = 640
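NUM_WORKERS = 0 keeps data loading in the main process, which is the most portable default for notebooks; on Linux it can be increased for faster data loading.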
Imports¶
[ ]:
import requests
from super_gradients.training import Trainer, models
from super_gradients.training.datasets.pose_estimation_datasets import YoloNASPoseCollateFN
from super_gradients.training.metrics import PoseEstimationMetrics
from super_gradients.training.models.pose_estimation_models.yolo_nas_pose import YoloNASPosePostPredictionCallback
from super_gradients.training.transforms.keypoints import (
    KeypointsBrightnessContrast,
    KeypointsHSV,
    KeypointsImageStandardize,
    KeypointsLongestMaxSize,
    KeypointsPadIfNeeded,
    KeypointsRandomAffineTransform,
    KeypointsRandomHorizontalFlip,
    KeypointsRemoveSmallObjects,
)
from super_gradients.training.utils.callbacks import Callback
from tlc.core import KeypointHelper, Table
from tlc.integration.super_gradients import PoseEstimationDataset, PoseEstimationMetricsCollectionCallback
from tlc_tools.split import split_table
from torch.utils.data import DataLoader
from torchmetrics.metric import Metric
Download pretrained model¶
[ ]:
from pathlib import Path

MODEL_PATH = Path(DOWNLOAD_PATH) / "yolo_nas_pose_n_coco_pose.pth"

if not MODEL_PATH.exists():
    MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    response = requests.get("https://sg-hub-nv.s3.amazonaws.com/models/yolo_nas_pose_n_coco_pose.pth")
    response.raise_for_status()  # Fail early instead of writing an error response to the checkpoint file
    MODEL_PATH.write_bytes(response.content)
Load and split input tables¶
[ ]:
initial_table = Table.from_names(TABLE_NAME, DATASET_NAME, PROJECT_NAME)
def split_by(table_row):
    """Callable to get the label of the first keypoint instance.

    This allows us to do a stratified split by label, just like in the original SuperGradients notebook.
    """
    return table_row["keypoints_2d"]["instances_additional_data"]["label"][0]


train_val_test = split_table(
    initial_table,
    splits={"train": 0.8, "val_test": 0.2},
    split_strategy="stratified",
    split_by=split_by,
    random_seed=42,
    shuffle=False,
)
test_val = split_table(
    train_val_test["val_test"],
    splits={"val": 0.5, "test": 0.5},
    split_strategy="stratified",
    split_by=split_by,
    shuffle=False,
    random_seed=42,
)
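Together, the two splits yield an 80/10/10 train/val/test partition, stratified on the label of the first keypoint instance in each row.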
[ ]:
train_table = train_val_test["train"]
val_table = test_val["val"]
test_table = test_val["test"]
[ ]:
print(initial_table)
print(train_table)
print(val_table)
print(test_table)
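As a quick sanity check, the split sizes can be compared against the initial table (a minimal sketch; this assumes the 3LC Table objects support len(), as materialized tables do):

[ ]:
for name, table in [("train", train_table), ("val", val_table), ("test", test_table)]:
    print(f"{name}: {len(table)} rows ({len(table) / len(initial_table):.0%} of {len(initial_table)})")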
Prepare for training¶
[ ]:
def create_transforms(image_size: int, flip_indices: list[int]):
    keypoints_random_horizontal_flip = KeypointsRandomHorizontalFlip(flip_index=flip_indices, prob=0.5)
    keypoints_hsv = KeypointsHSV(prob=0.5, hgain=20, sgain=20, vgain=20)
    keypoints_brightness_contrast = KeypointsBrightnessContrast(
        prob=0.5, brightness_range=[0.8, 1.2], contrast_range=[0.8, 1.2]
    )
    keypoints_random_affine_transform = KeypointsRandomAffineTransform(
        max_rotation=0,
        min_scale=0.5,
        max_scale=1.5,
        max_translate=0.1,
        image_pad_value=127,
        mask_pad_value=1,
        prob=0.75,
        interpolation_mode=[0, 1, 2, 3, 4],
    )
    keypoints_longest_max_size = KeypointsLongestMaxSize(max_height=image_size, max_width=image_size)
    keypoints_pad_if_needed = KeypointsPadIfNeeded(
        min_height=image_size,
        min_width=image_size,
        image_pad_value=[127, 127, 127],
        mask_pad_value=1,
        padding_mode="bottom_right",
    )
    keypoints_image_standardize = KeypointsImageStandardize(max_value=255)
    keypoints_remove_small_objects = KeypointsRemoveSmallObjects(min_instance_area=1, min_visible_keypoints=1)

    train_transforms = [
        keypoints_random_horizontal_flip,
        keypoints_hsv,
        keypoints_brightness_contrast,
        keypoints_random_affine_transform,
        keypoints_longest_max_size,
        keypoints_pad_if_needed,
        keypoints_image_standardize,
        keypoints_remove_small_objects,
    ]
    val_transforms = [
        keypoints_longest_max_size,
        keypoints_pad_if_needed,
        keypoints_image_standardize,
    ]
    return train_transforms, val_transforms


def create_training_params(max_epochs: int, callbacks: list[Callback], metrics: list[Metric], oks_sigmas: list[float]):
    return {
        "seed": 42,
        "warmup_mode": "LinearBatchLRWarmup",
        "warmup_initial_lr": 1e-8,
        "lr_warmup_epochs": 2,
        "initial_lr": 5e-4,
        "lr_mode": "cosine",
        "cosine_final_lr_ratio": 0.05,
        "max_epochs": max_epochs,
        "zero_weight_decay_on_bias_and_bn": True,
        "batch_accumulate": 1,
        "average_best_models": False,
        "save_ckpt_epoch_list": [],
        "loss": "yolo_nas_pose_loss",
        "criterion_params": {
            "oks_sigmas": oks_sigmas,
            "classification_loss_weight": 1.0,
            "classification_loss_type": "focal",
            "regression_iou_loss_type": "ciou",
            "iou_loss_weight": 2.5,
            "dfl_loss_weight": 0.01,
            "pose_cls_loss_weight": 1.0,
            "pose_reg_loss_weight": 34.0,
            "pose_classification_loss_type": "focal",
            "rescale_pose_loss_with_assigned_score": True,
            "assigner_multiply_by_pose_oks": True,
        },
        "optimizer": "AdamW",
        "optimizer_params": {"weight_decay": 0.000001},
        "ema": True,
        "ema_params": {"decay": 0.997, "decay_type": "threshold"},
        "mixed_precision": True,
        "sync_bn": False,
        "valid_metrics_list": metrics,
        "phase_callbacks": callbacks,
        "pre_prediction_callback": None,
        "metric_to_watch": "AP",
        "greater_metric_to_watch_is_better": True,
    }
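The learning rate warms up linearly from 1e-8 to the initial rate of 5e-4 over the first two epochs, then follows a cosine decay down to 5% of the initial rate (2.5e-5) by the final epoch.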
[ ]:
flip_indices = KeypointHelper.get_flip_indices_from_table(initial_table)
oks_sigmas = KeypointHelper.get_oks_sigmas_from_table(initial_table)
train_transforms, val_transforms = create_transforms(image_size=IMAGE_SIZE, flip_indices=flip_indices)
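The flip indices (which keypoint maps to which under a horizontal flip) and the per-keypoint OKS sigmas used by the Object Keypoint Similarity metric are both derived from the keypoint schema of the 3LC table.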
[ ]:
train_dataset = PoseEstimationDataset(train_table, transforms=train_transforms)
val_dataset = PoseEstimationDataset(val_table, transforms=val_transforms)
[ ]:
post_prediction_callback = YoloNASPosePostPredictionCallback(
    pose_confidence_threshold=0.01,
    nms_iou_threshold=0.7,
    pre_nms_max_predictions=100,
    post_nms_max_predictions=15,
)
pose_estimation_metrics = PoseEstimationMetrics(
    num_joints=train_dataset.num_joints,
    oks_sigmas=oks_sigmas,
    max_objects_per_image=15,
    post_prediction_callback=post_prediction_callback,
)
tlc_callback = PoseEstimationMetricsCollectionCallback(project_name=PROJECT_NAME, run_name=RUN_NAME)
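The post-prediction callback decodes raw model outputs into poses, applying confidence filtering and NMS, before metrics are computed. The 3LC callback associates the training session with the given project and run name so that collected metrics can be explored in the 3LC Dashboard.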
[ ]:
training_params = create_training_params(
    max_epochs=MAX_EPOCHS,
    callbacks=[tlc_callback],
    metrics=[pose_estimation_metrics],
    oks_sigmas=oks_sigmas,
)
[ ]:
train_dataloader_params = {
    "shuffle": True,
    "batch_size": BATCH_SIZE,
    "drop_last": True,
    "pin_memory": False,
    "collate_fn": YoloNASPoseCollateFN(),
    "num_workers": NUM_WORKERS,
    "persistent_workers": NUM_WORKERS > 0,
}
val_dataloader_params = {
    "shuffle": False,
    "batch_size": BATCH_SIZE,
    "drop_last": True,
    "pin_memory": False,
    "collate_fn": YoloNASPoseCollateFN(),
    "num_workers": NUM_WORKERS,
    "persistent_workers": NUM_WORKERS > 0,
}
train_dataloader = DataLoader(train_dataset, **train_dataloader_params)
val_dataloader = DataLoader(val_dataset, **val_dataloader_params)
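Optionally, pull a single batch to verify the full pipeline before committing to a training run (a minimal sketch; it assumes the image tensor is the first element produced by YoloNASPoseCollateFN, which is a positional-layout assumption):

[ ]:
batch = next(iter(train_dataloader))
images = batch[0]  # Assumed layout: image tensor first, targets after
print(images.shape)  # Expected: (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)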
Train model¶
[ ]:
yolo_nas_pose = models.get(
    MODEL_NAME,
    num_classes=20,  # AnimalPose has 20 keypoints
    checkpoint_path=MODEL_PATH.as_posix(),
    checkpoint_num_classes=17,  # The pretrained COCO checkpoint has 17 keypoints
).cuda()
trainer = Trainer(experiment_name=RUN_NAME, ckpt_root_dir=DOWNLOAD_PATH + "/sg-checkpoints")
[ ]:
trainer.train(
    model=yolo_nas_pose,
    training_params=training_params,
    train_loader=train_dataloader,
    valid_loader=val_dataloader,
)
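After training completes, the held-out test split can be evaluated with the same metrics (a sketch; it assumes Trainer.test accepts a test dataloader built to mirror the validation setup):

[ ]:
test_dataset = PoseEstimationDataset(test_table, transforms=val_transforms)
test_dataloader = DataLoader(test_dataset, **val_dataloader_params)
test_results = trainer.test(
    model=yolo_nas_pose,
    test_loader=test_dataloader,
    test_metrics_list=[pose_estimation_metrics],
)
print(test_results)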