Embeddings#
This article provides a deeper look at how embeddings are collected, processed, and visualized in 3LC.
What are Neural Network Embeddings?#
In the context of 3LC, embeddings refer to the activations from specific layers in neural network models. These activations represent the model’s internal states and can be used to understand how the model processes input data.
Embeddings can be:
Activations of arbitrary layers in a neural network: intermediate representations of data as it passes through the network.
Output of purpose-built embedding models/layers: e.g., Word2Vec or BERT for text data.
For a more general introduction to embeddings in machine learning, see, for example, the Google ML Foundational Course, which contains relevant background for understanding embeddings in the context of 3LC.
Importance of Embeddings in Machine Learning#
Embeddings are crucial because they:
Extract Features: Capture relevant features of input data.
Enable Data Visualization: Dimensionality reduction allows for visualization of high-dimensional data, aiding in understanding model behavior and data distribution.
Facilitate Clustering and Classification: Help identify clusters, outliers, and patterns in data, valuable for classification and anomaly detection tasks.
Collecting Embeddings in 3LC#
3LC provides tools to collect, manage, and visualize embeddings from neural networks.
In 3LC, embeddings are handled just like any other per-sample metric. Refer to the metrics collection user guide for more information on collecting metrics in general, or see the CIFAR-10 Notebook for a full example of training a network while collecting embeddings using a metrics collector.
Identifying Embedding Layers#
Identifying which layers of your model produce useful embeddings depends on the model architecture and the task. Some common examples include:
Image Classification: The next-to-last layer before the classification layer is typically a good choice.
Object Detection: The last layer of the backbone network or bounding-box predictor layers (e.g., heads predicting bounding box coordinates and class scores) are often suitable.
3LC identifies layers by their linear index into the sequence returned by the model’s named_modules() method.
To see all modules with their corresponding indices, you can use the following code snippet:
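import timm
# Minimal sketch: list every module together with the linear index 3LC uses to identify it.
# Any torch.nn.Module works the same way; resnet18 from timm is used here as in the examples below.
model = timm.create_model("resnet18", pretrained=False, num_classes=2, in_chans=1)
for layer_index, (name, module) in enumerate(model.named_modules()):
    print(f"{layer_index}: {name} ({type(module).__name__})")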
Discovering Embedding Sizes#
Once you have determined the indices of the layers you want to collect embeddings from, you need to understand the shapes of the embeddings produced by these layers. Some layers will produce embeddings of fixed size, while others may produce embeddings of varying sizes depending on the input data.
Using the torchinfo package, you can discover the shapes of various embeddings in a model, given an input size.
import torchinfo
import timm
model = timm.create_model("resnet18", pretrained=False, num_classes=2, in_chans=1)
torchinfo.summary(model, input_size=(1, 1, 64, 64))
# Output:
# ==========================================================================================
# Layer (type:depth-idx) Output Shape Param #
# ==========================================================================================
# ResNet [1, 2] --
# ├─Conv2d: 1-1 [1, 64, 32, 32] 3,136
# ├─BatchNorm2d: 1-2 [1, 64, 32, 32] 128
# ├─ReLU: 1-3 [1, 64, 32, 32] --
# ├─MaxPool2d: 1-4 [1, 64, 16, 16] --
# ├─Sequential: 1-5 [1, 64, 16, 16] --
# │ └─BasicBlock: 2-1 [1, 64, 16, 16] --
# ...
# ├─Sequential: 1-8 [1, 512, 2, 2] --
# │ └─BasicBlock: 2-7 [1, 512, 2, 2] --
# │ │ └─Conv2d: 3-51 [1, 512, 2, 2] 1,179,648
# │ │ └─BatchNorm2d: 3-52 [1, 512, 2, 2] 1,024
# │ │ └─Identity: 3-53 [1, 512, 2, 2] --
# │ │ └─ReLU: 3-54 [1, 512, 2, 2] --
# │ │ └─Identity: 3-55 [1, 512, 2, 2] --
# │ │ └─Conv2d: 3-56 [1, 512, 2, 2] 2,359,296
# │ │ └─BatchNorm2d: 3-57 [1, 512, 2, 2] 1,024
# │ │ └─Sequential: 3-58 [1, 512, 2, 2] 132,096
# │ │ └─ReLU: 3-59 [1, 512, 2, 2] --
# │ └─BasicBlock: 2-8 [1, 512, 2, 2] --
# │ │ └─Conv2d: 3-60 [1, 512, 2, 2] 2,359,296
# │ │ └─BatchNorm2d: 3-61 [1, 512, 2, 2] 1,024
# │ │ └─Identity: 3-62 [1, 512, 2, 2] --
# │ │ └─ReLU: 3-63 [1, 512, 2, 2] --
# │ │ └─Identity: 3-64
Reshape Strategies#
After embeddings have been collected, they must be converted to a fixed size before further processing. This is achieved by reshaping the embeddings, converting them from multi-dimensional tensors to one-dimensional vectors. Note that this is not the same as dimensionality reduction; rather, it is a required step to ensure the embedding of each sample is a 1D vector of fixed size.
We provide several strategies for reshaping embeddings, which are passed as an argument to the EmbeddingsMetricsCollector class:
mean: Take the mean across all non-first dimensions (excluding the batch dimension).
flatten: Flatten all dimensions after the batch dimension.
avg_pool_1_1, avg_pool_2_2, avg_pool_3_3: Use average pooling with a given output size to ensure consistent shapes.
A custom function that takes the embedding tensor as input and returns a reshaped tensor.
Explicit reshaping might not be necessary if the embeddings are already of fixed size, or if all inputs to the model have the same size. In this case, the flatten strategy will ensure the output is one-dimensional.
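As a rough illustration (not the exact 3LC implementation), the following PyTorch snippet shows what the flatten and avg_pool_2_2 strategies approximately correspond to for a convolutional activation of shape (batch, channels, height, width):
import torch
import torch.nn.functional as F
# A hypothetical activation from a convolutional layer: (batch, channels, height, width)
activation = torch.randn(4, 64, 16, 16)
# "flatten": flatten all dimensions after the batch dimension -> (4, 64 * 16 * 16)
flattened = activation.flatten(start_dim=1)
print(flattened.shape)  # torch.Size([4, 16384])
# "avg_pool_2_2" (approximately): average-pool to a fixed 2x2 spatial size, then flatten.
# The resulting shape (4, 64 * 2 * 2) is independent of the input height and width.
pooled = F.adaptive_avg_pool2d(activation, (2, 2)).flatten(start_dim=1)
print(pooled.shape)  # torch.Size([4, 256])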
How Are the Embeddings Collected?
When using an EmbeddingsMetricsCollector, the outputs of specific layers in the model are collected using forward hooks. These hooks are registered on the model’s layers to capture the activations as the data passes through the network. This is handled automatically by 3LC during metrics collection, but it is also possible to collect embeddings manually, as in the following example:
import torch
import timm
model = timm.create_model("resnet18", pretrained=False, num_classes=2, in_chans=1)
layer_to_collect = 10 # Index of the layer to collect
def hook(module, input, output):
print(f"Embedding shape: {output.shape}")
# ...save the output of the layer
# Register the hook on the desired layer
for layer_index, (name, module) in enumerate(model.named_modules()):
if layer_index == layer_to_collect:
print("Registering hook on layer:", name)
module.register_forward_hook(hook)
break
# Run the model
data = torch.randn(1, 1, 224, 224)
output = model(data)
data = torch.randn(1, 1, 128, 128)
output = model(data)
# Output:
# Registering hook on layer: layer1.0.act1
# Embedding shape: torch.Size([1, 64, 56, 56])
# Embedding shape: torch.Size([1, 64, 32, 32])
Collection Examples#
Below is an example of collecting embeddings from a model and reshaping them using the EmbeddingsMetricsCollector:
"""Collect embeddings from layers 10 and 11 of a model"""
import torch
import tlc
table = ... # a tlc.Table representing the input data
model = ... # a torch.nn.Module
collector = tlc.EmbeddingsMetricsCollector(
layers=[10, 11],
reshape_strategy={
10: "flatten",
11: "avg_pool_2_2",
},
)
# Wrap the model in a Predictor object in order to collect intermediate activations
predictor = tlc.Predictor(model, layers=[10, 11])
# Collect embeddings. The embeddings will be stored in a metrics table associated with the current active Run.
tlc.collect_metrics(table, collector, predictor)
For more details on tlc.Predictor, and how to combine data from tables with models, see the Predictor user guide.
Next we show an example of collecting embeddings from a pretrained ViT model, using a custom function to reshape the embeddings.
"""Collect embeddings from the `layernorm` layer of ViT"""
from transformers import ViTModel
import torch
import tlc
table: tlc.Table = ... # A table returning suitable input data (image-tensors of shape (3, 224, 224))
dataloader = torch.utils.data.DataLoader(table, batch_size=4)
model = ViTModel.from_pretrained("google/vit-base-patch16-224")
predictor = tlc.Predictor(model, layers=[223]) # Index of the `layernorm` layer
def reshape_function(layer_output):
"""Extract embeddings corresponding to the [CLS] token from the `layernorm` layer of a Vision Transformer model.
This function is designed to specifically extract the embeddings corresponding to the [CLS] token. In ViT models, the
[CLS] token is typically used as a summary representation of the entire image. By slicing with `layer_output[:, 0, :]`,
we select the embeddings of the [CLS] token from the layer's output, which has dimensions (batch_size, sequence_length,
hidden_size). This operation produces a tensor of shape (batch_size, hidden_size), making it suitable for further
analysis or downstream tasks. For this model, hidden size is 768."""
return layer_output[:, 0, :]
collector = tlc.EmbeddingsMetricsCollector(
layers=[223],
reshape_strategy={
223: reshape_function,
},
)
# Collect embeddings from the first batch only:
for batch in dataloader:
predictor_output = predictor(batch)
collected_embeddings = collector(batch, predictor_output)
break
print(collected_embeddings)
# Output:
# {'embeddings_223': tensor([[ 0.6929, -0.1778, 1.4126, ..., -0.2675, -0.2489, 0.6796],
# [-0.0734, -0.2110, 0.3797, ..., 0.5803, 0.7378, 0.1924],
# [-1.5251, -1.0982, 0.9778, ..., 1.3487, -0.2712, -0.5566],
# [-0.3959, -0.5170, 0.3551, ..., -0.4861, -0.4202, -0.7321]],
# device='cuda:0')}
Dimensionality Reduction in 3LC#
Dimensionality reduction techniques help transform high-dimensional embeddings into lower-dimensional spaces for easier visualization and analysis.
Importance of Dimensionality Reduction#
Dimensionality reduction is important because it:
Improves Visualization: High-dimensional data is challenging to visualize; reducing dimensions helps in plotting and interpreting data.
Enhances Clustering and Classification: Reduced dimensions can reveal inherent clusters and relationships within the data.
Simplifies Models: Lower-dimensional representations can lead to simpler models with faster training times and reduced risk of overfitting.
Reducing Embeddings#
3LC currently supports two popular methods for dimensionality reduction:
UMAP: Uniform Manifold Approximation and Projection reduces the dimensionality of data by constructing a high-dimensional graph and optimizing a low-dimensional graph to preserve the structure.
PaCMAP: Pairwise Controlled Manifold Approximation Projection optimizes a graph-based objective function to preserve both local and global data structure by balancing different pairwise distances during dimensionality reduction.
All dimensionality reduction algorithms in 3LC are implemented as Table types, operating on an input table and producing a new table containing the reduced data. These tables provide a common set of parameters for controlling the reduction process, as well as parameters specific to each algorithm.
Common parameters to all dimensionality reduction tables:
| Parameter | Description |
|---|---|
| source_embedding_column | The name of the column containing the data to be reduced (must be 1-dimensional vectors) |
| target_embedding_column | The name of the column in the output table containing the dimensionality-reduced data |
| retain_source_embedding_column | Whether the input column should be present in the output table |
While the low-level dimensionality reduction is implemented in the UMAPTable and PaCMAPTable table types, in practice you will use one of the higher-level helper functions in the reduce module, or the Run object’s Run.reduce_embeddings_by_foreign_table_url() and Run.reduce_embeddings_per_dataset() methods.
When embeddings are collected during training using a metrics collector, the result is a sequence of metrics tables containing the model’s embeddings across iterations. The Run methods mentioned above are designed to reuse the same reduction model across multiple metrics tables in specific ways. This allows, for example, the reducer to be fit on the final-epoch embeddings and then applied to all intermediate embeddings, which is useful for visualizing how the embeddings evolve over time.
For an overview of method-specific arguments, see the PaCMAPTableArgs and UMapTableArgs classes.
Reduction Examples#
We demonstrate dimensionality reduction with a few examples.
"""Reduce a column containing embeddings using UMAP"""
import numpy as np
import tlc
# Create a table with random embeddings (100 embeddings of size 100)
table = tlc.Table.from_dict(
data={
"embedding": np.random.rand(100, 100).tolist(),
},
table_name="embeddings",
description="A table with random embeddings",
add_weight_column=False,
)
print(f"Created embedding table with columns: {table.columns}")
# Create a reducer to reduce the embeddings to 2D using UMAP
umap_args = {
"n_components": 2,
"n_neighbors": 5,
"retain_source_embedding_column": True,
}
reducer = tlc.create_reducer("umap", umap_args)
# Reduce the table
reduced_table_url = reducer.fit_reduction_method(table, "embedding")
reduced_table = tlc.Table.from_url(reduced_table_url)
print(f"Created reduced table with columns: {reduced_table.columns}")
# Output:
# Created embedding table with columns: ['embedding']
# Created reduced table with columns: ['embedding', 'embedding_umap']
"""Reduce all embedding-columns in a Run's metrics tables using PaCMAP"""
import tlc
train_table: tlc.Table = ... # A table containing the training data
val_table: tlc.Table = ... # A table containing the validation data
run: tlc.Run = ... # A Run object containing metrics tables with embeddings
# Let's assume embeddings were collected across multiple epochs on both splits
# It is often desirable to see how points move through embedding-space over time, and in order
# to accomplish this, we need to re-use the same reduction model across several metrics tables.
# We can choose between two high-level APIs for reducing all metrics tables in a Run:
# 1. Reduce embeddings using the most recent embeddings produced by a specific table revision
# This method identifies a single metrics table to fit the reduction model. This model is applied
# on all other metrics tables in the Run. We set delete_source_tables=True to delete the source
# tables after reduction, in order to save disk-space.
url_mapping = run.reduce_embeddings_by_foreign_table_url(train_table.url, delete_source_tables=True)
# 2. Reduce embeddings using the most recent embeddings originating from the same table
# This method identifies the most recent metrics table for each dataset in the Run,
# and fits reduction models on these tables. These models are then applied to all other
# metrics tables corresponding to the same table.
url_mapping = run.reduce_embeddings_per_dataset(delete_source_tables=True)
# The url_mapping dictionary contains the mapping between the original metrics table URLs
# and the URLs of the reduced tables.
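As a minimal, hypothetical follow-up, the mapping can be inspected to locate the reduced tables:
for source_url, reduced_url in url_mapping.items():
    print(f"{source_url} -> {reduced_url}")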