View source Download .ipynb

Collect SAM image embeddings¶

In this example, we will show how to create a Run containing embeddings extracted from the Segment Anything Model (SAM) for a set of images.

image1

Setup project¶

[ ]:
# 3LC project that holds the COCO128 table and receives the new Run.
PROJECT_NAME = "3LC Tutorials - COCO128"
# Key into `sam_model_registry` selecting which SAM variant to build.
MODEL_TYPE = "vit_b"
# URL the checkpoint is downloaded from when it is not already on disk.
MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
# Directory in which the downloaded checkpoint is stored.
DOWNLOAD_PATH = "../../transient_data"
# Number of components the collected embeddings are reduced to.
EMBEDDING_DIM = 3
# Dimensionality-reduction method handed to the Run's reduction call.
REDUCTION_METHOD = "umap"
# Number of samples per batch (intended for metrics collection).
BATCH_SIZE = 4

Install dependencies¶

[ ]:
# 3LC with its optional `umap` extra (the reduction method used in this notebook).
%pip install 3lc[umap]
# Segment Anything (SAM) model code, installed straight from the GitHub repository.
%pip install git+https://github.com/facebookresearch/segment-anything
# 3LC example helpers (presumably the source of the `tlc_tools` import — TODO confirm).
%pip install git+https://github.com/3lc-ai/3lc-examples

Imports¶

[ ]:
from pathlib import Path

import cv2
import tlc
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

from tlc_tools.common import infer_torch_device

Download model weights¶

[ ]:
# Local path where the SAM checkpoint is cached between notebook runs.
CHECKPOINT = DOWNLOAD_PATH + "/sam_vit_b_01ec64.pth"

# Only download when the checkpoint is not already cached on disk.
if not Path(CHECKPOINT).exists():
    # The download writes into the target directory but does not create it,
    # so make sure it exists first (a fresh checkout has no DOWNLOAD_PATH yet).
    Path(CHECKPOINT).parent.mkdir(parents=True, exist_ok=True)
    torch.hub.download_url_to_file(MODEL_URL, CHECKPOINT)

Set up model and preprocessing¶

[ ]:
# Torch device used for the model and for the preprocessed image tensors.
# NOTE(review): presumably picks an accelerator when available and falls back to CPU — confirm in tlc_tools.
device = infer_torch_device()
[ ]:
def create_model():
    """Build the SAM variant named by MODEL_TYPE from the local checkpoint.

    The model is moved to ``device`` and switched to evaluation mode.

    Returns:
        The ready-to-use SAM model.
    """
    build_sam = sam_model_registry[MODEL_TYPE]
    model = build_sam(checkpoint=CHECKPOINT)
    model.to(device)
    model.eval()
    return model
[ ]:
# Instantiate SAM once; both transforms below are derived from this instance.
sam_model = create_model()
# Resizer configured with the image encoder's input size.
RESIZE_TRANSFORM = ResizeLongestSide(sam_model.image_encoder.img_size)
# SAM's own preprocessing step, applied to each image after resizing.
PREPROCESS_TRANSFORM = sam_model.preprocess


def transform_to_sam_format(sample):
    """Load a sample's image from disk and convert it into SAM encoder input.

    Args:
        sample: Table row. ``sample["image"]`` is handed to ``cv2.imread`` and
            must therefore be the path of a readable image file.

    Returns:
        A dict with a single ``"image"`` key holding the preprocessed tensor,
        created on ``device``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image at the given path.
    """
    image_path = sample["image"]
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        # cv2.imread reports failure by returning None rather than raising,
        # which would otherwise surface as an obscure cvtColor error below.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # OpenCV loads images in BGR channel order; convert to RGB before resizing.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    resized = RESIZE_TRANSFORM.apply_image(rgb_image)

    # Move channels first (HWC -> CHW) before handing the tensor to SAM's preprocessing.
    tensor = torch.as_tensor(resized, device=device).permute(2, 0, 1).contiguous()
    tensor = PREPROCESS_TRANSFORM(tensor)

    return {"image": tensor}

Create 3LC Table and Run¶

[ ]:
# Reuse the COCO128 table from ../1-create-tables/create-table-from-coco and apply the transformation defined above,
# so that every sample is converted by `transform_to_sam_format`.
table = tlc.Table.from_names("initial", "COCO128", PROJECT_NAME).map(transform_to_sam_format)

# Initialize a 3LC Run in the same project; the embeddings are reduced through this Run further down.
run = tlc.init(
    project_name=PROJECT_NAME,
    run_name="Collect SAM embeddings",
    description="Collect embeddings for the COCO128 dataset using the SAM model",
)

Collect embeddings using SAM¶

[ ]:
# The collector and the predictor are configured with the same layer indices.
# NOTE(review): index 0 presumably refers to the first layer of the wrapped image encoder — confirm.
embedding_layers = [0]

embeddings_metrics_collector = tlc.EmbeddingsMetricsCollector(layers=embedding_layers)

# Only SAM's image encoder is wrapped, not the full model.
predictor = tlc.Predictor(
    sam_model.image_encoder,
    layers=embedding_layers,
    unpack_dicts=True,  # samples are dicts of the form {"image": tensor}
    device=device,
)

# Forward the configured BATCH_SIZE, which was previously declared but never used,
# to the dataloader that collection iterates over.
tlc.collect_metrics(
    table,
    embeddings_metrics_collector,
    predictor,
    dataloader_args={"batch_size": BATCH_SIZE},
)

Reduce dimensionality of embeddings¶

[ ]:
# Reduce the collected embeddings to EMBEDDING_DIM components using REDUCTION_METHOD.
# The embeddings to reduce are selected via the URL of the table they were collected on.
run.reduce_embeddings_by_foreign_table_url(
    table.url,
    method=REDUCTION_METHOD,
    n_components=EMBEDDING_DIM,
)