Embeddings#

This article provides a deeper look at how embeddings are collected, processed, and visualized in 3LC.

Dimensionality reduced embeddings collected from the CIFAR-10 dataset. The embeddings were generated by a ResNet-18 model and reduced to 3D using PaCMAP. We can clearly see the separation of classes in the reduced space, indicating that the model has learned to distinguish between different classes. Selecting points from the regions between the cluster centers can help identify misclassified, under-represented, or difficult samples.

What are Neural Network Embeddings?#

In the context of 3LC, embeddings refer to the activations from specific layers in neural network models. These activations represent the model’s internal states and can be used to understand how the model processes input data.

Embeddings can be:

  • Activations of arbitrary layers in a neural network: intermediate representations of data as it passes through the network.

  • Output of purpose-built embedding models/layers: e.g., Word2Vec or BERT for text data.

For a more general introduction to embeddings in Machine Learning, see for example the Google ML Foundational Course, which contains relevant background for understanding embeddings in the context of 3LC.

Importance of Embeddings in Machine Learning#

Embeddings are crucial because they:

  • Extract Features: Capture relevant features of input data.

  • Enable Data Visualization: Dimensionality reduction allows for visualization of high-dimensional data, aiding in understanding model behavior and data distribution.

  • Facilitate Clustering and Classification: Help identify clusters, outliers, and patterns in data, valuable for classification and anomaly detection tasks.

Collecting Embeddings in 3LC#

3LC provides tools to collect, manage, and visualize embeddings from neural networks.

In 3LC, embeddings are handled just as any other per-sample metric. Refer to the metrics collection user guide for more information on collecting metrics in general, or see the CIFAR-10 Notebook for a full example of training a network while collecting embeddings using a metrics collector.

Identifying Embedding Layers#

Identifying which layers of your model produce useful embeddings depends on the model architecture and the task. Some common examples include:

  • Image Classification: The penultimate layer, just before the final classification layer, is typically a good choice.

  • Object Detection: The last layer of the backbone network or bounding-box predictor layers (e.g., heads predicting bounding box coordinates and class scores) are often suitable.

3LC identifies layers by their linear index into the iterator returned by a module's named_modules() method.

To see all modules with their corresponding indices, you can use the following code snippet:

import timm

model = timm.create_model("resnet18", pretrained=False, num_classes=2, in_chans=1)

for i, (name, module) in enumerate(model.named_modules()):
    print(f"{i}: {name}")

# Output:
# 0: 
# 1: conv1
# 2: bn1
# ...
# 90: global_pool
# 91: global_pool.pool
# 92: global_pool.flatten
# 93: fc

Discovering Embedding Sizes#

Once you have determined the indices of the layers you want to collect embeddings from, you need to understand the shapes of the embeddings produced by these layers. Some layers will produce embeddings of fixed size, while others may produce embeddings of varying sizes depending on the input data.

Using the torchinfo package, you can discover the shapes of various embeddings in a model, given an input size.

import torchinfo
import timm

model = timm.create_model("resnet18", pretrained=False, num_classes=2, in_chans=1)
torchinfo.summary(model, input_size=(1, 1, 64, 64))

# Output:
# ==========================================================================================
# Layer (type:depth-idx)                   Output Shape              Param #
# ==========================================================================================
# ResNet                                   [1, 2]                    --
# ├─Conv2d: 1-1                            [1, 64, 32, 32]           3,136
# ├─BatchNorm2d: 1-2                       [1, 64, 32, 32]           128
# ├─ReLU: 1-3                              [1, 64, 32, 32]           --
# ├─MaxPool2d: 1-4                         [1, 64, 16, 16]           --
# ├─Sequential: 1-5                        [1, 64, 16, 16]           --
# │    └─BasicBlock: 2-1                   [1, 64, 16, 16]           --
# ...
# ├─Sequential: 1-8                        [1, 512, 2, 2]            --
# │    └─BasicBlock: 2-7                   [1, 512, 2, 2]            --
# │    │    └─Conv2d: 3-51                 [1, 512, 2, 2]            1,179,648
# │    │    └─BatchNorm2d: 3-52            [1, 512, 2, 2]            1,024
# │    │    └─Identity: 3-53               [1, 512, 2, 2]            --
# │    │    └─ReLU: 3-54                   [1, 512, 2, 2]            --
# │    │    └─Identity: 3-55               [1, 512, 2, 2]            --
# │    │    └─Conv2d: 3-56                 [1, 512, 2, 2]            2,359,296
# │    │    └─BatchNorm2d: 3-57            [1, 512, 2, 2]            1,024
# │    │    └─Sequential: 3-58             [1, 512, 2, 2]            132,096
# │    │    └─ReLU: 3-59                   [1, 512, 2, 2]            --
# │    └─BasicBlock: 2-8                   [1, 512, 2, 2]            --
# │    │    └─Conv2d: 3-60                 [1, 512, 2, 2]            2,359,296
# │    │    └─BatchNorm2d: 3-61            [1, 512, 2, 2]            1,024
# │    │    └─Identity: 3-62               [1, 512, 2, 2]            --
# │    │    └─ReLU: 3-63                   [1, 512, 2, 2]            --
# │    │    └─Identity: 3-64
# ...
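If you only need the output shape of a single candidate layer, a forward hook is a lightweight alternative to a full summary. The snippet below is a minimal sketch using plain PyTorch; the index 90 refers to the global_pool module in the named_modules() listing above, and the printed shape is what we would expect for this model, not a guaranteed value.

import timm
import torch

model = timm.create_model("resnet18", pretrained=False, num_classes=2, in_chans=1)

# Look up the module by the same linear index used above (90 -> "global_pool")
name, module = list(model.named_modules())[90]

shapes = {}
handle = module.register_forward_hook(
    lambda mod, args, output: shapes.update({name: tuple(output.shape)})
)

with torch.no_grad():
    model(torch.randn(1, 1, 64, 64))  # Same input size as the torchinfo example

handle.remove()
print(shapes)  # Expected something like {'global_pool': (1, 512)}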

Reshape Strategies#

After embeddings have been collected, they must be converted to a fixed size before further processing. This is achieved by reshaping the embeddings, converting them from multi-dimensional tensors to one-dimensional vectors. Note that this is not the same as dimensionality reduction; rather, it is a required step to ensure the embedding of each sample is a 1D vector of fixed size.

We provide several strategies for reshaping embeddings, which are passed as an argument to the EmbeddingsMetricsCollector class:

  • mean: Take the mean across all dimensions except the first (batch) dimension.

  • flatten: Flatten all dimensions after the batch dimension.

  • avg_pool_1_1, avg_pool_2_2, avg_pool_3_3: Use average pooling with a given output size to ensure consistent shapes.

  • A custom function that takes the embedding tensor as input and returns a reshaped tensor.

Explicit reshaping might not be necessary if the embeddings are already of fixed size, or if all inputs to the model have the same size. In this case, the flatten strategy will ensure the output is one-dimensional.
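To illustrate what the flatten and avg_pool strategies do to a typical convolutional activation, here is a rough sketch using plain PyTorch operations. This is not the 3LC implementation, only an approximation of the resulting shapes.

import torch
import torch.nn.functional as F

# A batch of 4 activations from a conv layer: (batch, channels, height, width)
activations = torch.randn(4, 512, 4, 4)

# "flatten": flatten everything after the batch dimension -> shape (4, 8192)
flattened = activations.flatten(start_dim=1)

# "avg_pool_2_2": average-pool to a fixed 2x2 spatial grid, giving a consistent
# size regardless of input resolution; flattened this gives shape (4, 2048)
pooled = F.adaptive_avg_pool2d(activations, (2, 2)).flatten(start_dim=1)

print(flattened.shape, pooled.shape)
# torch.Size([4, 8192]) torch.Size([4, 2048])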

Embedding sizes

Be aware that collecting very large embedding layers can lead to high memory usage and slow performance. It is recommended to collect embeddings from layers with a manageable size, typically in the 100s to low 1000s of elements. Using mean, avg_pool, or custom reshape strategies can help reduce the size of the embeddings before storage. The embeddings will be stored on disk, but can optionally be deleted during dimensionality reduction to save space.

Embeddings collected using the EmbeddingsMetricsCollector are automatically assigned the number role “nn_embedding”. This currently has two side-effects: these columns are not sent to the 3LC Dashboard to save bandwidth, and they are automatically selected as input to dimensionality reduction methods.

Collection Examples#

Below is an example of collecting embeddings from a model and reshaping them using the EmbeddingsMetricsCollector:

"""Collect embeddings from layers 10 and 11 of a model"""
import torch
import tlc

table = ... # a tlc.Table representing the input data
model = ... # a torch.nn.Module
collector = tlc.EmbeddingsMetricsCollector(
    layers=[10, 11],
    reshape_strategy={
        10: "flatten",
        11: "avg_pool_2_2",
    },
)

# Wrap the model in a Predictor object in order to collect intermediate activations
predictor = tlc.Predictor(model, layers=[10, 11])

# Collect embeddings. The embeddings will be stored in a metrics table associated with the current active Run.
tlc.collect_metrics(table, collector, predictor)

For more details on tlc.Predictor, and how to combine data from tables with models, see the Predictor user guide.

Next we show an example of collecting embeddings from a pretrained ViT model, using a custom function to reshape the embeddings.

"""Collect embeddings from the `layernorm` layer of ViT"""
from transformers import ViTModel
import torch
import tlc

table: tlc.Table = ... # A table returning suitable input data (image-tensors of shape (3, 224, 224))
dataloader = torch.utils.data.DataLoader(table, batch_size=4)
model = ViTModel.from_pretrained("google/vit-base-patch16-224")
predictor = tlc.Predictor(model, layers=[223])  # Index of the `layernorm` layer

def reshape_function(layer_output):
    """Extract embeddings corresponding to the [CLS] token from the `layernorm` layer of a Vision Transformer model.

    This function is designed to specifically extract the embeddings corresponding to the [CLS] token. In ViT models, the
    [CLS] token is typically used as a summary representation of the entire image. By slicing with `layer_output[:, 0, :]`,
    we select the embeddings of the [CLS] token from the layer's output, which has dimensions (batch_size, sequence_length,
    hidden_size). This operation produces a tensor of shape (batch_size, hidden_size), making it suitable for further
    analysis or downstream tasks. For this model, hidden size is 768."""
    return layer_output[:, 0, :]

collector = tlc.EmbeddingsMetricsCollector(
    layers=[223],
    reshape_strategy={
        223: reshape_function,
    },
)

# Collect embeddings from the first batch only:
for batch in dataloader:
    predictor_output = predictor(batch)
    collected_embeddings = collector(batch, predictor_output)
    break

print(collected_embeddings)

# Output:
# {'embeddings_223': tensor([[ 0.6929, -0.1778,  1.4126,  ..., -0.2675, -0.2489,  0.6796],
#         [-0.0734, -0.2110,  0.3797,  ...,  0.5803,  0.7378,  0.1924],
#         [-1.5251, -1.0982,  0.9778,  ...,  1.3487, -0.2712, -0.5566],
#         [-0.3959, -0.5170,  0.3551,  ..., -0.4861, -0.4202, -0.7321]],
#        device='cuda:0')}

Dimensionality Reduction in 3LC#

An example of dimensionality reduction applied to a 3D point cloud. In the top left is the original 3D point cloud. The two panels on the right show two different parameterizations of the PaCMAP algorithm applied to the point cloud. The bottom left panel shows the UMAP algorithm applied to the point cloud. This example illustrates the importance of not only which algorithm to use, but also the choice of hyperparameters, which can have a significant impact on the usefulness of the reduced representation.

Dimensionality reduction techniques help transform high-dimensional embeddings into lower-dimensional spaces for easier visualization and analysis.

Importance of Dimensionality Reduction#

Dimensionality reduction is important because it:

  • Improves Visualization: High-dimensional data is challenging to visualize; reducing dimensions helps in plotting and interpreting data.

  • Enhances Clustering and Classification: Reduced dimensions can reveal inherent clusters and relationships within the data.

  • Simplifies Models: Lower-dimensional representations can lead to simpler models with faster training times and reduced risk of overfitting.

Reducing Embeddings#

3LC currently supports two popular methods for dimensionality reduction:

  • UMAP: Uniform Manifold Approximation and Projection reduces the dimensionality of data by constructing a high-dimensional graph and optimizing a low-dimensional graph to preserve the structure.

  • PaCMAP: Pairwise Controlled Manifold Approximation Projection optimizes a graph-based objective function to preserve both local and global data structure by balancing different pairwise distances during dimensionality reduction.

All dimensionality reduction algorithms in 3LC are implemented as Table types, operating on an input table and producing a new table containing the reduced data. These tables provide a common set of parameters for controlling the reduction process, as well as parameters specific to each algorithm.

Common parameters to all dimensionality reduction tables:

  • source_embedding_column: The name of the column containing the data to be reduced (must contain 1-dimensional vectors).

  • target_embedding_column: The name of the column in the output table that will contain the dimensionality reduced data.

  • retain_source_embedding_column: Whether the input column should also be present in the output table.

While the low-level dimensionality reduction is implemented in the UMAPTable and PaCMAPTable table types, in practice you will use one of the higher-level helper functions in the reduce module, or the Run methods Run.reduce_embeddings_by_foreign_table_url() and Run.reduce_embeddings_per_dataset().

When embeddings are collected during training using a metrics collector, the result is a sequence of metrics tables containing the model's embeddings across iterations. The Run methods mentioned above are designed to reuse the same reduction model across multiple metrics tables in specific ways. This allows, for example, the reducer to be fit on the final-epoch embeddings and then applied to all intermediate embeddings, which is useful for visualizing how the embeddings evolve over time.

For an overview of method-specific arguments, see the PaCMAPTableArgs and UMapTableArgs classes.

Reduction Examples#

We demonstrate dimensionality reduction with a few examples.

"""Reduce a column containing embeddings using UMAP"""
import numpy as np
import tlc

# Create a table with random embeddings (100 embeddings of size 100)
table = tlc.Table.from_dict(
    data={
        "embedding": np.random.rand(100, 100).tolist(),
    },
    table_name="embeddings",
    description="A table with random embeddings",
    add_weight_column=False,
)

print(f"Created embedding table with columns: {table.columns}")

# Create a reducer to reduce the embeddings to 2D using UMAP
umap_args = {
    "n_components": 2,
    "n_neighbors": 5,
    "retain_source_embedding_column": True,
}
reducer = tlc.create_reducer("umap", umap_args)

# Reduce the table
reduced_table_url = reducer.fit_reduction_method(table, "embedding")
reduced_table = tlc.Table.from_url(reduced_table_url)

print(f"Created reduced table with columns: {reduced_table.columns}")

# Output:
# Created embedding table with columns: ['embedding']
# Created reduced table with columns: ['embedding', 'embedding_umap']

"""Reduce all embedding-columns in a Run's metrics tables using PaCMAP"""
import tlc

train_table: tlc.Table = ... # A table containing the training data
val_table: tlc.Table = ... # A table containing the validation data
run: tlc.Run = ... # A Run object containing metrics tables with embeddings

# Let's assume embeddings were collected across multiple epochs on both splits
# It is often desirable to see how points move through embedding-space over time, and in order
# to accomplish this, we need to re-use the same reduction model across several metrics tables.

# We can choose between two high-level APIs for reducing all metrics tables in a Run:

# 1. Reduce embeddings using the most recent embeddings produced by a specific table revision
#    This method identifies a single metrics table to fit the reduction model. This model is applied
#    on all other metrics tables in the Run. We set delete_source_tables=True to delete the source
#    tables after reduction, in order to save disk-space.

url_mapping = run.reduce_embeddings_by_foreign_table_url(train_table.url, delete_source_tables=True)

# 2. Reduce embeddings using the most recent embeddings originating from the same table
#    This method identifies the most recent metrics table for each dataset in the Run,
#    and fits reduction models on these tables. These models are then applied to all other
#    metrics tables corresponding to the same table.

url_mapping = run.reduce_embeddings_per_dataset(delete_source_tables=True)

# The url_mapping dictionary contains the mapping between the original metrics table URLs
# and the URLs of the reduced tables.
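
Assuming url_mapping behaves as an ordinary dictionary of URLs (a hedged sketch, not a complete example), the reduced metrics tables can then be opened like any other table:

# Open each reduced metrics table from the mapping returned above
for original_url, reduced_url in url_mapping.items():
    reduced_table = tlc.Table.from_url(reduced_url)
    print(f"{original_url} -> {reduced_url} with columns {reduced_table.columns}")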