View source Download .ipynb

# Convert bounding boxes to instance segmentation masks using SAM

This notebook creates a derived Table with an added column containing instance segmentation masks. The masks are generated by the Segment Anything Model (SAM), using the Table’s existing bounding boxes as prompts.

img

## Install dependencies

[ ]:
# Install the packages this notebook needs into the running kernel's environment.
# NOTE(review): versions are unpinned; pin them (e.g. `pkg==X.Y.Z`) for reproducible runs.
# 3LC client library (provides the `tlc` module imported below).
%pip install 3lc
# 3LC examples repository; presumably provides the `tlc_tools` package imported below — confirm.
%pip install git+https://github.com/3lc-ai/3lc-examples.git
# Meta's Segment Anything (SAM) model code.
%pip install git+https://github.com/facebookresearch/segment-anything
# Used to plot the example mask at the end of the notebook.
%pip install matplotlib

## Imports

[ ]:
from pathlib import Path

import matplotlib.pyplot as plt
import requests
import tlc
from tqdm import tqdm

from tlc_tools.sam_autosegment import bbs_to_segments

## Project setup

[ ]:
# --- 3LC project -----------------------------------------------------------
PROJECT_NAME = "3LC Tutorials - COCO128"

# --- Local storage ---------------------------------------------------------
# Relative to the notebook's working directory; the SAM checkpoint is cached here.
DOWNLOAD_PATH = "../../transient_data"

# --- SAM model -------------------------------------------------------------
# SAM backbone variant passed to bbs_to_segments as `sam_model_type`.
MODEL_TYPE = "vit_b"

## Load the input table

Load the input table created in the `create-table-from-coco.ipynb` notebook.

[ ]:
# Identify the table produced by create-table-from-coco.ipynb.
# NOTE(review): positional order assumed to be (table name, dataset name, project name) — confirm.
INPUT_TABLE_NAME = "initial"
INPUT_DATASET_NAME = "COCO128"
input_table = tlc.Table.from_names(INPUT_TABLE_NAME, INPUT_DATASET_NAME, PROJECT_NAME)

## Download the SAM model checkpoint

[ ]:
# Official SAM checkpoint filenames keyed by model type, so the downloaded
# weights always match MODEL_TYPE (loading mismatched weights fails).
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}
SAM_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"

if MODEL_TYPE not in SAM_CHECKPOINTS:
    raise ValueError(f"Unsupported MODEL_TYPE {MODEL_TYPE!r}; expected one of {sorted(SAM_CHECKPOINTS)}")

checkpoint_name = SAM_CHECKPOINTS[MODEL_TYPE]
checkpoint_path = Path(DOWNLOAD_PATH) / checkpoint_name

if not checkpoint_path.exists():
    model_url = SAM_BASE_URL + checkpoint_name
    # Stream into a temporary ".part" file and only move it into place once the
    # download has finished. Otherwise an interrupted run leaves a truncated file
    # that the `exists()` check above would accept as a valid checkpoint forever.
    partial_path = checkpoint_path.with_name(checkpoint_path.name + ".part")

    print(f"Downloading SAM checkpoint to {checkpoint_path}...")
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(model_url, stream=True, timeout=60) as response:
        # Fail loudly on HTTP errors instead of saving an error page as the checkpoint.
        response.raise_for_status()
        # A missing/zero content-length means "unknown size" for the progress bar.
        total_size = int(response.headers.get("content-length", 0)) or None
        with (
            open(partial_path, "wb") as f,
            tqdm(
                desc="Downloading",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar,
        ):
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
                    bar.update(len(chunk))
    # Atomic on the same filesystem: the checkpoint path only ever holds a complete file.
    partial_path.replace(checkpoint_path)
    print("Download completed.")
else:
    print(f"Checkpoint already exists at {checkpoint_path}.")

## Run the SAM model

[ ]:
# Prompt SAM with the table's existing bounding boxes; the result is a derived
# table with an added segmentation column.
sam_options = {
    "sam_model_type": MODEL_TYPE,
    "checkpoint": checkpoint_path.as_posix(),
    "description": "Added segmentation column from bounding boxes",
}
out_table = bbs_to_segments(input_table, **sam_options)

## Visualize an example mask

[ ]:
# Pull the SAM-generated masks for one sample row (index 3) and report their shape.
example_row = out_table[3]
example_mask = example_row["segments"]["masks"]
print(example_mask.shape)
[ ]:
# Display the first instance mask of the example row.
# NOTE(review): assumes masks are laid out as (height, width, num_instances), matching
# the `[:, :, 0]` indexing used here — confirm against bbs_to_segments.
if example_mask.shape[2] == 0:
    # A row without any bounding boxes yields no masks; indexing [:, :, 0] would raise.
    print("The example row has no instance masks to display.")
else:
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(example_mask[:, :, 0], cmap="gray")
    ax.set_title("Example SAM instance mask (first instance)")
    ax.axis("off")
    plt.show()