Per-Bounding-Box Embeddings Example

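This notebook demonstrates how to extract embeddings for the individual bounding boxes in a 3LC (`tlc`) Table using a pretrained EfficientNet model. Each bounding box is cropped from its image and passed through the model, and the output of a hidden layer is used as that box's embedding. The embeddings are then reduced to three dimensions with UMAP and stored as an extra column in a new output Table.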

Since the example uses a classification model, we also get a class prediction for each bounding box. The predicted label (the argmax of the classifier outputs) is stored as an additional column in the Table.

Project Setup

[2]:
# --- Notebook configuration: edit these values to fit your environment ---

# Set to True to pip-install the required packages in the next cell.
INSTALL_DEPENDENCIES = False

# 3LC project / dataset names under which the Tables are created.
PROJECT_NAME = "Bounding Box Embeddings"
DATASET_NAME = "Balloons"

# Folder checked for an optional trained classifier checkpoint ("bb_classifier.pth").
TRANSIENT_DATA_PATH = "../transient_data"
# Folder containing the Balloons sample data ("balloons/train/...").
TEST_DATA_PATH = "./data"

TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
[4]:
%%capture
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install timm
    %pip --quiet install 3lc[umap]

Imports

[7]:
from __future__ import annotations

from io import BytesIO

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

import tlc

Set Up the Input Table

We use `tlc.Table.from_coco` to create a TableFromCoco that loads the “Balloons” dataset from an annotations file and a folder of images.

[8]:
# Resolve the Balloons train split (COCO JSON + image folder) to absolute URLs
# before handing them to from_coco.
images_dir = tlc.Url(TEST_DATA_PATH + "/balloons/train").to_absolute()
annotations_file = tlc.Url(TEST_DATA_PATH + "/balloons/train/train-annotations.json").to_absolute()

# Location at which the COCO-backed input Table is written.
table_url = tlc.Url.create_table_url(
    table_name="table_from_coco",
    dataset_name=DATASET_NAME,
    project_name=PROJECT_NAME,
)

# Any existing Table at table_url is replaced.
input_table = tlc.Table.from_coco(
    table_url=table_url,
    image_folder=images_dir,
    annotations_file=annotations_file,
    if_exists="overwrite",
    description="Balloons train split from COCO annotations",
)
[9]:
# Fetch the value map of the bounding-box labels and the schema of a single
# bounding box; the schema is needed later to crop boxes out of their images.
import json

label_map = input_table.get_value_map("bbs.bb_list.label")
bb_schema = input_table.rows_schema.values["bbs"].values["bb_list"]

print(f"Input table uses {len(label_map)} unique label(s): {json.dumps(label_map, indent=2)}")
Input table uses 1 unique label(s): {
  "0.0": {
    "internal_name": "balloon",
    "display_name": "",
    "description": "",
    "display_color": "",
    "url": ""
  }
}

Initialize the Model

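Now we load an EfficientNet-B0 model. If a locally trained bounding-box classifier checkpoint exists (`bb_classifier.pth` in `TRANSIENT_DATA_PATH`), it is loaded. Otherwise, we download ImageNet-pretrained weights. In that case, the model's classification head is not trained on this dataset, so its predicted labels are not meaningful; the hidden-layer embeddings still come from the pretrained backbone.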

[10]:
# Choose where to run the model: CUDA if present, then Apple-silicon MPS, else CPU.
device_name = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

device = torch.device(device_name)
print(f"Using device: {device}")
Using device: cuda:0
[11]:
import os

import timm

# The classifier scores every label of the input table plus one extra
# "background" class at index len(label_map), matching the label value map used
# for the output table. For Balloons this is 2, which is the value the local
# checkpoint was previously hard-coded to (presumably it was trained with this
# same background-class convention -- verify if you train your own checkpoint).
num_classes = len(label_map) + 1

# Prefer a locally trained bounding-box classifier; otherwise fall back to
# ImageNet-pretrained EfficientNet-B0 weights.
bb_classifier_path = TRANSIENT_DATA_PATH + "/bb_classifier.pth"
if os.path.exists(bb_classifier_path):
    model = timm.create_model("efficientnet_b0", num_classes=num_classes, checkpoint_path=bb_classifier_path).to(device)
    print("Loaded pretrained model")
else:
    print("Downloading pretrained model")
    # NOTE: with a num_classes different from the pretrained config, timm does not
    # load pretrained classifier weights -- the head is untrained for this data, so
    # only the embeddings (not the predicted labels) are meaningful in this branch.
    model = timm.create_model("efficientnet_b0", num_classes=num_classes, pretrained=True).to(device)

# Evaluation mode for inference (affects layers such as dropout / batch norm).
model = model.eval()
Downloading pretrained model
[12]:
# The hidden layer we will use as embeddings
# The hidden layer we will use as embeddings: the `flatten` submodule of the
# model's global pooling stage. A forward hook on it (registered in the next cell)
# captures one feature vector per crop -- 1280 features for efficientnet_b0,
# judging by the UMAP input shape printed later. Presumably these are the pooled
# features that feed the classifier head (timm internals; not shown here).
hidden_layer = model.global_pool.flatten

Collecting Bounding Box Embeddings

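In this section, we crop each bounding box out of its image, run the crops through the model in mini-batches, and capture each crop's hidden-layer output with a forward hook. We also record the predicted label of each crop.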

[13]:
# Image preprocessing for each bounding-box crop: resize, center-crop to 224x224,
# convert to a tensor and normalize with the ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Per-crop results, in the order the crops are processed:
#   all_embeddings     - raw model outputs (classifier logits), one vector per crop
#   all_labels         - argmax of those outputs, i.e. the predicted class index
#   all_hidden_outputs - hooked hidden-layer activations (used as embeddings later)
all_embeddings: list[np.ndarray] = []
all_labels: list[int] = []
all_hidden_outputs: list[np.ndarray] = []

# Buffer that the forward hook appends the hidden-layer output to.
output_list: list[torch.Tensor] = []


def hook_fn(module, layer_input, layer_output):
    """Store the output of the hooked layer."""
    output_list.append(layer_output)


# Batched inference setup
batch_size = 4
mini_batch: list[torch.Tensor] = []
# batch_to_image_map[i] is the input-table row index of the i-th processed crop.
batch_to_image_map: list[int] = []


def run_inference_on_batch(mini_batch: list[torch.Tensor]) -> None:
    """Run the model on a list of preprocessed crops and record the results.

    Appends, for every crop in ``mini_batch``, the model output, the predicted
    label and the hooked hidden-layer output to the module-level result lists.
    """
    mini_batch_tensor = torch.stack(mini_batch).to(device)
    output_list.clear()  # never mix in stale hook outputs from an earlier forward pass
    with torch.no_grad():
        mini_batch_logits = model(mini_batch_tensor)

    # Collect the hook output of this forward pass, then clear the buffer.
    mini_batch_hidden = output_list[-1].cpu().numpy()
    output_list.clear()
    all_hidden_outputs.extend(mini_batch_hidden)

    all_embeddings.extend(mini_batch_logits.cpu().numpy())
    mini_batch_labels = torch.argmax(mini_batch_logits, dim=1)
    all_labels.extend(mini_batch_labels.cpu().numpy())


# Keep the hook registered only while inference runs. try/finally guarantees its
# removal even if the loop raises, so re-running this cell never stacks hooks.
hook_handle = hidden_layer.register_forward_hook(hook_fn)
try:
    for row_idx, row in tqdm(enumerate(input_table), total=len(input_table), desc="Running inference on table"):
        image_bbs = row["bbs"]["bb_list"]
        if len(image_bbs) == 0:
            continue
        image_filename = row["image"]
        image_bytes = tlc.Url(image_filename).read()
        # Force 3 channels: grayscale, RGBA or palette images would otherwise not
        # match the 3-channel Normalize in `preprocess`.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        w, h = image.size

        for bb in image_bbs:
            bb_crop = tlc.BBCropInterface.crop(image, bb, bb_schema, h, w)
            bb_crop_tensor = preprocess(bb_crop)

            # Flush the mini_batch before adding would exceed batch_size
            if len(mini_batch) >= batch_size:
                run_inference_on_batch(mini_batch)
                mini_batch.clear()

            mini_batch.append(bb_crop_tensor)
            batch_to_image_map.append(row_idx)

    # Run inference on remaining items in mini_batch if it's not empty
    if len(mini_batch) > 0:
        run_inference_on_batch(mini_batch)
        mini_batch.clear()
finally:
    # Remove the hook
    hook_handle.remove()

Dimensionality Reduction

Once the embeddings are collected, we reduce their dimensionality with UMAP. This projects the high-dimensional hidden-layer features down to three dimensions for easier visualization and analysis.

[14]:
import umap

# Stack the per-crop hidden-layer outputs into a (num_crops, num_features) matrix.
all_embeddings_np = np.vstack(all_hidden_outputs)
print(f"UMAP input shape: {all_embeddings_np.shape}")

# Fit UMAP to a 3D projection. A fixed random_state makes the embedding
# reproducible across runs (UMAP then gives up some parallelism).
UMAP_RANDOM_STATE = 42
reducer = umap.UMAP(n_components=3, random_state=UMAP_RANDOM_STATE)
embedding_3d = reducer.fit_transform(all_embeddings_np)
UMAP input shape: (255, 1280)

Create a new Table containing the embeddings and labels as extra columns

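Finally, we combine the reduced embeddings and predicted labels with the input Table to write a new Table. First, the flat per-crop results are regrouped so that each row gets the results for its own bounding boxes.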

[15]:
# Repack embeddings and labels into groups per image
grouped_embeddings: list[list[np.ndarray]] = [[] for _ in range(len(input_table))]
grouped_labels: list[list[int]] = [[] for _ in range(len(input_table))]

for img_idx, embed, label in zip(batch_to_image_map, embedding_3d, all_labels):
    grouped_labels[img_idx].append(label)
    grouped_embeddings[img_idx].append(embed)

Set Up the Schema of the Output Table

[16]:
# Create a schema for the embeddings
per_bb_embedding_schema = tlc.Schema(
    value=tlc.Float32Value(number_role=tlc.NUMBER_ROLE_XYZ_COMPONENT),
    size0=tlc.DimensionNumericValue(value_min=3, value_max=3),  # 3D embedding
    size1=tlc.DimensionNumericValue(value_min=0, value_max=1000),  # Max 1000 bbs per image
    sample_type="hidden",  # Hide this column when iterating over the "sample view" of the table
    writable=False,
)

# Create a schema with a label map for the labels
label_value_map = {
    **label_map,
    len(label_map): tlc.MapElement("background"),
}

label_schema = tlc.Schema(
    value=tlc.Int32Value(value_map=label_value_map),
    size0=tlc.DimensionNumericValue(value_min=0, value_max=1000),
    sample_type="hidden",  # Hide this column when iterating over the "sample view" of the table
    writable=False,
)

schemas = {
    "per_bb_embeddings": per_bb_embedding_schema,
    "per_bb_labels": label_schema,
}
schemas.update(input_table.row_schema.values)  # Copy over the schemas from the input table

Write the Output Table

We will use a TableWriter to write the output table as a TableFromParquet.

[17]:
from collections import defaultdict

# Writer for the output Table (a TableFromParquet); it records input_table as
# its input and replaces any existing Table of the same name.
table_writer = tlc.TableWriter(
    table_name="added_embeddings_and_labels",
    description="Bounding box embeddings and labels",
    project_name=PROJECT_NAME,
    dataset_name=DATASET_NAME,
    input_tables=[input_table.url],
    column_schemas=schemas,
    if_exists="overwrite",
)

# TableWriter.add_batch takes a mapping of column name -> list of row values.
data = defaultdict(list)

# Copy every column of every input row verbatim ...
for source_row in input_table.table_rows:
    for col_name, col_value in source_row.items():
        data[col_name].append(col_value)

# ... then attach the grouped per-bounding-box embeddings and predicted labels.
data["per_bb_embeddings"] = grouped_embeddings
data["per_bb_labels"] = grouped_labels

table_writer.add_batch(data)
new_table = table_writer.finalize()

Inspect the Properties of the Output Table

[18]:
# Quick sanity check of the written Table: row count, column names, and its
# location relative to the input Table.
row_count = len(new_table)
print(row_count)

column_names = new_table.columns
print(column_names)

relative_location = new_table.url.to_relative(input_table.url)
print(relative_location)
61
['image_id', 'image', 'width', 'height', 'bbs', 'weight', 'per_bb_embeddings', 'per_bb_labels']
../added_embeddings_and_labels