Per-Bounding-Box Embeddings Example

This notebook demonstrates how to extract embeddings for bounding boxes in 3LC Tables using a pretrained EfficientNet model. The generated embeddings are then reduced in dimensionality and stored as extra columns in a new output Table.

Since the example uses a classification model, we can also extract class probabilities for each bounding box. The predicted labels are also stored as additional columns in the Table.

Info: This notebook demonstrates a technique for adding columns to an existing Table. While this is a useful technique, there are some drawbacks to this approach:

  • The new Table will not be part of the lineage of the input Table.

  • The new Table will contain a literal copy of all data in the input Table.

In a future release of 3LC, adding columns to an existing Table will be supported natively.

[2]:
# Identifiers used to build the 3LC table URLs
PROJECT_NAME = "Bounding Box Embeddings"
DATASET_NAME = "Balloons"

# Relative data locations (resolved against the kernel's working directory)
TEST_DATA_PATH = "../../tests/test_data/data"
TRANSIENT_DATA_PATH = "../transient_data"

# Flags
INSTALL_DEPENDENCIES = False  # True -> pip-install the requirements in the next cell
TLC_PUBLIC_EXAMPLES_DEVELOPER_MODE = True
[4]:
%%capture
# Optional dependency installation, guarded by INSTALL_DEPENDENCIES from the config cell.
# %%capture swallows this cell's output so the pip logs do not flood the notebook.
# torch/torchvision are pulled from the PyTorch CUDA 11.8 (cu118) wheel index.
# tlc[umap] presumably provides the `umap` package imported in the dimensionality-reduction step.
# NOTE(review): no versions are pinned here; consider pinning for reproducible runs.
if INSTALL_DEPENDENCIES:
    %pip --quiet install torch --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install torchvision --index-url https://download.pytorch.org/whl/cu118
    %pip --quiet install timm
    %pip --quiet install tlc[umap]
[7]:
from __future__ import annotations

from io import BytesIO

from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import torch
from torchvision import transforms

import tlc

Set Up Input Table

We will use a TableFromCoco to load the “Balloons” dataset from a COCO-format annotations file and a folder of images.

[9]:
# Absolute locations of the COCO annotations file and the matching image folder
annotations_file = tlc.Url(TEST_DATA_PATH + "/balloons/train/train-annotations.json").to_absolute()
images_dir = tlc.Url(TEST_DATA_PATH + "/balloons/train").to_absolute()

# Destination URL of the table inside the project/dataset hierarchy
table_url = tlc.Url.create_table_url(
    table_name="table_from_coco",
    dataset_name=DATASET_NAME,
    project_name=PROJECT_NAME,
)

# Build the input table; "overwrite" lets the notebook be re-run without a name clash
input_table = tlc.Table.from_coco(
    image_folder=images_dir,
    annotations_file=annotations_file,
    table_url=table_url,
    if_exists="overwrite",
)
[10]:
# Look up the per-box schema (needed to crop boxes) and the label value map of the input table
import json

bbs_column_schema = input_table.rows_schema.values["bbs"]
bb_schema = bbs_column_schema.values["bb_list"]

label_map = input_table.get_value_map("bbs.bb_list.label")
label_count = len(label_map)
label_map_text = json.dumps(label_map, indent=2)
print(f"Input table uses {label_count} unique label(s): {label_map_text}")
Input table uses 1 unique label(s): {
  "0.0": {
    "internal_name": "balloon",
    "display_name": "",
    "description": "",
    "display_color": "",
    "url": ""
  }
}

Initialize the Model

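Now we load the EfficientNet model. If a trained classifier checkpoint is available locally, it is loaded. Otherwise, we download pretrained EfficientNet-B0 weights. In that case the classification head is freshly initialized, because its size differs from the pretrained one, so the predicted labels are only meaningful once the model has been trained.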

[11]:
import os

import timm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The classifier predicts every label of the input table plus one extra "background"
# class. The output label schema later reserves index len(label_map) for it. Both
# branches below must use this same class count; previously the checkpoint branch
# hard-coded 2 while the download branch used len(label_map). With a single output,
# argmax is always 0, so "background" could never be predicted.
num_classes = len(label_map) + 1

# Initialize a pretrained classifier model. A local checkpoint is preferred.
# Otherwise timm's pretrained EfficientNet-B0 weights are downloaded; their
# classification head is re-initialized for `num_classes` outputs, so predicted
# labels are only meaningful after training.
bb_classifier_path = TRANSIENT_DATA_PATH + "/bb_classifier.pth"
if os.path.exists(bb_classifier_path):
    model = timm.create_model(
        "efficientnet_b0", num_classes=num_classes, checkpoint_path=bb_classifier_path
    ).to(device)
    print("Loaded pretrained model")
else:
    print("Downloading pretrained model")
    model = timm.create_model("efficientnet_b0", num_classes=num_classes, pretrained=True).to(device)

# Inference only: switch dropout/batch-norm layers to evaluation behavior
model = model.eval()
Using device: cuda:0
Downloading pretrained model
[12]:
# Layer whose output is captured as the per-box embedding: the flatten step of the
# model's global pooling stage (yields one feature vector per input crop).
pooling_stage = model.global_pool
hidden_layer = pooling_stage.flatten

Collecting Bounding Box Embeddings

In this section, we’ll walk through the process of extracting embeddings for each bounding box present in our input images.

[13]:
# Image preprocessing: resize/center-crop to 224x224, convert to a tensor and
# normalize with the ImageNet channel statistics.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Per-box accumulators, filled in box order by run_inference_on_batch:
#   all_embeddings     - raw model outputs (class logits) per box
#   all_labels         - argmax class index per box
#   all_hidden_outputs - hooked hidden-layer activations (the embeddings used below)
all_embeddings: list[np.ndarray] = []
all_labels: list[int] = []
all_hidden_outputs: list[np.ndarray] = []

# Receives the hidden layer's output from the forward hook
output_list: list[torch.Tensor] = []


def hook_fn(module, input, output):
    """Store the output of the hooked layer."""
    output_list.append(output)


# Batched inference setup
batch_size = 4
mini_batch: list[torch.Tensor] = []
# batch_to_image_map[i] is the input-table row index of the i-th box sent through the model
batch_to_image_map: list[int] = []


def run_inference_on_batch(mini_batch: list[torch.Tensor]) -> None:
    """Run the model on one mini-batch of crops and append results to the accumulators.

    Args:
        mini_batch: Preprocessed crop tensors, all of the same shape.
    """
    mini_batch_tensor = torch.stack(mini_batch).to(device)
    with torch.no_grad():
        mini_batch_embeddings = model(mini_batch_tensor)

    # One forward pass ran, so the hook captured this batch's activations last.
    # Clear the list afterwards so captures can never leak into a later batch.
    mini_batch_hidden = output_list.pop().cpu().numpy()
    output_list.clear()
    all_hidden_outputs.extend(mini_batch_hidden)

    all_embeddings.extend(mini_batch_embeddings.cpu().numpy())
    mini_batch_labels = torch.argmax(mini_batch_embeddings, dim=1)
    all_labels.extend(mini_batch_labels.cpu().numpy())


hook_handle = hidden_layer.register_forward_hook(hook_fn)
try:
    for row_idx, row in tqdm(
        enumerate(input_table.table_rows), total=len(input_table), desc="Running inference on table"
    ):
        image_bbs = row["bbs"]["bb_list"]
        if len(image_bbs) == 0:
            continue
        image_filename = row["image"]
        image_bytes = tlc.Url(image_filename).read()
        # Force 3-channel RGB: grayscale/RGBA/palette images would otherwise break
        # the 3-channel Normalize step in `preprocess`.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        w, h = image.size

        for bb in image_bbs:
            bb_crop = tlc.BBCropInterface.crop(image, bb, bb_schema, h, w)
            bb_crop_tensor = preprocess(bb_crop)

            # Flush the mini-batch before adding this crop would overfill it
            if len(mini_batch) >= batch_size:
                run_inference_on_batch(mini_batch)
                mini_batch.clear()

            mini_batch.append(bb_crop_tensor)
            batch_to_image_map.append(row_idx)

    # Run inference on remaining items in mini_batch if it's not empty
    if mini_batch:
        run_inference_on_batch(mini_batch)
        mini_batch.clear()
finally:
    # Always detach the hook, even if inference fails. Otherwise re-running this
    # cell would register a second hook on the same layer.
    hook_handle.remove()

Dimensionality Reduction

Once the embeddings are collected, the next step is to reduce their dimensionality for easier analysis.

[14]:
import umap

# One hooked hidden-layer vector per bounding box, stacked into a 2D (boxes x features) array
all_embeddings_np = np.vstack(all_hidden_outputs)
print(f"UMAP input shape: {all_embeddings_np.shape}")
assert len(all_embeddings_np) == len(batch_to_image_map), "expected exactly one embedding per bounding box"

# Fit UMAP down to 3 dimensions. UMAP is stochastic, so a fixed random_state makes
# the projection reproducible across runs. umap-learn trades some parallelism for
# this determinism, which is negligible at this dataset size.
UMAP_RANDOM_STATE = 42
reducer = umap.UMAP(n_components=3, random_state=UMAP_RANDOM_STATE)
embedding_3d = reducer.fit_transform(all_embeddings_np)
UMAP input shape: (255, 1280)

Create a New Table Containing the Embeddings and Labels as Extra Columns

Finally, we combine the reduced embeddings and predicted labels with the input Table to write a new Table.

[15]:
# Repack the flat, per-box results into one list per input-table row.
# Rows without bounding boxes keep empty lists.
# zip() silently truncates, so verify first that the three per-box sequences line up.
assert len(batch_to_image_map) == len(embedding_3d) == len(all_labels), (
    "Per-box results are misaligned: "
    f"{len(batch_to_image_map)} boxes, {len(embedding_3d)} embeddings, {len(all_labels)} labels"
)

grouped_embeddings: list[list[np.ndarray]] = [[] for _ in range(len(input_table))]
grouped_labels: list[list[int]] = [[] for _ in range(len(input_table))]

for img_idx, embed, label in zip(batch_to_image_map, embedding_3d, all_labels):
    # Labels arrive as numpy integer scalars from argmax; store plain ints to match the annotation
    grouped_labels[img_idx].append(int(label))
    grouped_embeddings[img_idx].append(embed)

Set Up the Schema of the Output Table

[16]:

# Upper bound on the number of bounding boxes stored per image (shared by both schemas)
MAX_BBS_PER_IMAGE = 1000

# Create a schema for the embeddings
per_bb_embedding_schema = tlc.Schema(
    value=tlc.Float32Value(),
    size0=tlc.DimensionNumericValue(value_min=3, value_max=3),  # 3D embedding
    size1=tlc.DimensionNumericValue(value_min=0, value_max=MAX_BBS_PER_IMAGE),  # bbs per image
    sample_type="hidden",  # Hide this column when iterating over the "sample view" of the table
    writable=False,
)

# Create a schema with a label map for the labels.
# "background" is appended at index len(label_map), one past the input labels.
# The key is a float to match the float keys returned by get_value_map.
label_value_map = {
    **label_map,
    float(len(label_map)): tlc.MapElement("background"),
}
label_schema = tlc.Schema(
    value=tlc.Int32Value(value_map=label_value_map),
    size0=tlc.DimensionNumericValue(value_min=0, value_max=MAX_BBS_PER_IMAGE),  # bbs per image
    sample_type="hidden",  # Hide this column when iterating over the "sample view" of the table
    writable=False,
)

schemas = {
    "per_bb_embeddings": per_bb_embedding_schema,
    "per_bb_labels": label_schema,
}
schemas.update(input_table.row_schema.values)  # Copy over the schemas from the input table

Write the Output Table

We will use a TableWriter to write the output table as a TableFromParquet.

[17]:
from collections import defaultdict

table_writer = tlc.TableWriter(
    column_schemas=schemas,
    project_name=PROJECT_NAME,
    dataset_name=DATASET_NAME,
    table_name="added_embeddings_and_labels",
    if_exists="overwrite",
)

# Column-oriented batch for TableWriter: column name -> list of per-row values.
# Start with a verbatim copy of every input row ...
data = defaultdict(list)
for source_row in input_table.table_rows:
    for name, value in source_row.items():
        data[name].append(value)

# ... then attach the new per-box columns (one list per row, aligned with the input table)
data.update(
    per_bb_embeddings=grouped_embeddings,
    per_bb_labels=grouped_labels,
)

table_writer.add_batch(data)
new_table = table_writer.finalize()

Inspect the Properties of the Output Table

[18]:
# Row count, column names, and the output table's location relative to the input table
row_count = len(new_table)
print(row_count)
column_names = new_table.columns
print(column_names)
print(new_table.url.to_relative(input_table.url))
61
['image_id', 'image', 'width', 'height', 'bbs', 'weight', 'per_bb_embeddings', 'per_bb_labels']
../added_embeddings_and_labels