View source Download .ipynb

Train a YOLO model for pose estimation on a custom keypoints dataset

This notebook trains a YOLO model for pose estimation on the AnimalPose dataset.

The input Table used by this notebook is created by create-custom-keypoints-table.ipynb, so run that notebook first.

image1

Project setup

[ ]:
# Identifiers used to look up the input 3LC Table (project / dataset / table).
PROJECT_NAME = "3LC Tutorials - 2D Keypoints"
DATASET_NAME = "AnimalPose"
TABLE_NAME = "initial"

# Training parameters forwarded to model.train().
EPOCHS = 10
NUM_WORKERS = 0  # dataloader workers; presumably 0 means loading in the main process

# NOTE(review): DOWNLOAD_PATH is not referenced by any visible cell in this notebook.
DOWNLOAD_PATH = "../../transient_data"

Install dependencies

[ ]:
# Install the 3LC Ultralytics integration and the 3lc-examples repository
# (presumably the source of the `tlc_ultralytics` and `tlc_tools` packages imported below).
# NOTE(review): neither install is version-pinned, so results may change between releases.
%pip install 3lc-ultralytics
%pip install git+https://github.com/3lc-ai/3lc-examples.git

Imports

[ ]:
import tlc
from tlc_ultralytics import YOLO, Settings

from tlc_tools.split import split_table

Load and split table

[ ]:
# Load the input Table (created by create-custom-keypoints-table.ipynb) using the
# table, dataset and project names configured in the setup cell.
initial_table = tlc.Table.from_names(TABLE_NAME, DATASET_NAME, PROJECT_NAME)


def split_by(table_row):
    """Return the label of the first keypoint instance in ``table_row``.

    Used as the stratification key for ``split_table``, so that each split
    receives a similar label distribution (as in the original SuperGradients
    notebook). Only the first instance of a row is considered; a row with no
    instances would presumably fail on the ``[0]`` lookup.
    """
    instance_labels = table_row["keypoints_2d"]["instances_additional_data"]["label"]
    first_label = instance_labels[0]
    return first_label


# Options shared by both splits: stratify on the first instance's label (split_by),
# keep the original row order, and fix the seed so the splits are reproducible.
# Defining them once keeps the two split_table calls from drifting apart.
stratified_split_kwargs = dict(
    split_strategy="stratified",
    split_by=split_by,
    random_seed=42,
    shuffle=False,
)

# Step 1: 80% train, 20% held out for validation + test.
train_val_test = split_table(
    initial_table,
    splits={"train": 0.8, "val_test": 0.2},
    **stratified_split_kwargs,
)

# Step 2: halve the held-out part into val and test (10% / 10% of the full table).
test_val = split_table(
    train_val_test["val_test"],
    splits={"val": 0.5, "test": 0.5},
    **stratified_split_kwargs,
)
[ ]:
# Pull the three final splits out of the two split_table results.
train_table, val_table, test_table = (
    train_val_test["train"],
    test_val["val"],
    test_val["test"],
)
[ ]:
# Show the original table followed by the train, val and test splits.
for table in (initial_table, train_table, val_table, test_table):
    print(table)

Train model

[ ]:
# Start from the yolo11n-pose.pt checkpoint.
model = YOLO("yolo11n-pose.pt")

# 3LC run configuration: the project/run the results are recorded under and
# the extra data to collect (loss, 2-D image embeddings).
settings = Settings(
    project_name=PROJECT_NAME,
    run_name="train-yolon-animalpose",
    run_description="Training a YOLO model for pose estimation on the AnimalPose dataset",
    collect_loss=True,
    image_embeddings_dim=2,
)

# The settings must be passed to train(); otherwise none of the options above
# (run name/description, loss collection, embeddings) apply to this run.
model.train(
    tables={"train": train_table, "val": val_table},
    settings=settings,
    epochs=EPOCHS,
    workers=NUM_WORKERS,
)