tlc.client.torch.metrics.predictor#
A wrapper for calling PyTorch models.
Module Contents#
Classes#
Class | Description
---|---
PredictorArgs | Arguments for the Predictor class.
PredictorOutput | The output of the Predictor class.
Predictor | A wrapper for PyTorch models that handles preprocessing, device management, embedding extraction, and prediction.
API#
- class tlc.client.torch.metrics.predictor.PredictorArgs#
Arguments for the Predictor class.
- layers: Sequence[int] = field(...)#
The indices of the hidden layers to extract during a forward pass through the model. Hidden layers are returned as torch tensors and are collected using the forward-hooks mechanism. Layer indices correspond to the order of the modules returned by the wrapped model's named_modules() method.
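The forward-hook mechanism described above can be sketched in plain PyTorch. This is an illustrative sketch of the underlying mechanism, not the Predictor implementation itself; the two-layer model and the chosen layer index are made up for the example:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# named_modules() yields the root module first, then submodules in order;
# an index into this sequence selects the layer to hook.
modules = [m for _, m in model.named_modules()]
layer_index = 2  # the ReLU, given the ordering above

activations: dict[int, torch.Tensor] = {}

def hook(module, args, output):
    # Store the layer's output keyed by its index, mirroring the
    # {layer index: activation} shape described for PredictorOutput.
    activations[layer_index] = output.detach()

handle = modules[layer_index].register_forward_hook(hook)
model(torch.randn(1, 4))  # an ordinary forward pass triggers the hook
handle.remove()

print(activations[layer_index].shape)  # torch.Size([1, 8])
```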
- preprocess_fn: Callable[[tlc.core.builtins.types.SampleData], Any] | None = field(...)#
A function to preprocess the input data.
- device: torch.device | str | None = field(...)#
The device to use for processing.
- class tlc.client.torch.metrics.predictor.PredictorOutput#
The output of the Predictor class.
A dictionary where each key is the layer index, and the value is the activation of that layer.
- class tlc.client.torch.metrics.predictor.Predictor(model: torch.nn.Module, **predictor_args: Any)#
A wrapper for PyTorch models that handles preprocessing, device management, embedding extraction, and prediction.
A high-level wrapper around a PyTorch model designed to standardize the workflow of processing inputs, making predictions, and handling outputs. It serves to unify the interface for different PyTorch models, ensuring consistency and ease of use across various modeling tasks.
A Predictor can be configured to extract hidden layer activations from the model during a forward pass by supplying the PredictorArgs.layers argument to the constructor. These activations are stored in the PredictorOutput, and can be used for downstream tasks such as feature extraction, visualization, or debugging.
See the torch.nn.Module documentation for more information on PyTorch models and modules.
Initializes the Predictor with a model and optional arguments.
- Parameters:
model – A torch.nn.Module model for which predictions will be made.
**predictor_args – Arbitrary keyword arguments that will be passed to the PredictorArgs dataclass. These can include configurations such as which layers to hook for output, preprocessing functions, device specifications, and whether to unpack dictionaries or lists when passing data to the model.
- get_device() torch.device #
Determines the appropriate device for model computation.
If a device is specified in the predictor arguments, it is used. Otherwise, attempts to use the same device as the model parameters. Defaults to CPU if the model has no parameters.
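The device-resolution order documented above can be sketched as follows. This is an illustrative sketch of the documented behavior, not the actual get_device implementation:

```python
import torch
import torch.nn as nn

def resolve_device(model: nn.Module, device=None) -> torch.device:
    """Sketch of the documented resolution order: explicit argument,
    then the model's parameters, then CPU."""
    if device is not None:
        return torch.device(device)  # an explicitly specified device wins
    try:
        # Otherwise, use the same device as the model's parameters.
        return next(model.parameters()).device
    except StopIteration:
        # A model with no parameters defaults to CPU.
        return torch.device("cpu")

print(resolve_device(nn.Linear(2, 2)))         # cpu
print(resolve_device(nn.Identity()))           # cpu (no parameters)
print(resolve_device(nn.Linear(2, 2), "cpu"))  # cpu
```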
- preprocess(batch: tlc.core.builtins.types.SampleData) tlc.core.builtins.types.SampleData #
Applies preprocessing to the input batch, based on the predictor arguments.
If a custom preprocessing function is provided, it is used. Otherwise, default preprocessing is applied.
The default preprocessing behavior attempts to identify the input data within the batch using the following heuristics:
- If the batch is a list of dictionaries, it is assumed that the input data is already preprocessed.
- If the batch is a tuple or list, the first item is assumed to be the input data.
- If the batch is a dictionary, the input data is assumed to be under the keys image, images, or pixel_values.
To disable preprocessing, set the disable_preprocess argument to True.
- Parameters:
batch – The input data batch to preprocess.
- Returns:
The preprocessed data batch.
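The heuristics above can be sketched in plain Python. This is an illustrative sketch of the documented behavior only; the real preprocess method may differ in details:

```python
def default_preprocess(batch):
    """Sketch of the documented input-selection heuristics."""
    if isinstance(batch, list) and batch and all(isinstance(b, dict) for b in batch):
        return batch  # a list of dicts is assumed to be preprocessed already
    if isinstance(batch, (tuple, list)):
        return batch[0]  # the first item is assumed to be the input data
    if isinstance(batch, dict):
        # Look for the input under the documented keys.
        for key in ("image", "images", "pixel_values"):
            if key in batch:
                return batch[key]
    return batch

print(default_preprocess(([1, 2, 3], "labels")))       # [1, 2, 3]
print(default_preprocess({"pixel_values": "tensor"}))  # tensor
```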
- to_device(batch: tlc.core.builtins.types.SampleData) tlc.core.builtins.types.SampleData #
Moves the batch of data to the appropriate device.
This method uses a utility function to recursively move all tensors in the batch to the specified device.
- Parameters:
batch – The preprocessed data batch to move to the device.
- Returns:
The data batch, with all tensors moved to the specified device.
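The recursive tensor-moving utility the description refers to can be sketched like this. This is an illustrative sketch, not tlc's own helper:

```python
import torch

def move_to_device(obj, device):
    """Recursively move all tensors in a nested structure to a device."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    return obj  # non-tensor leaves are returned unchanged

batch = {"pixel_values": torch.zeros(2, 3), "meta": [torch.ones(1), "id-0"]}
moved = move_to_device(batch, torch.device("cpu"))
print(moved["pixel_values"].device)  # cpu
```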
- call_model(processed_batch: tlc.core.builtins.types.SampleData) tlc.core.builtins.types.SampleData #
Calls the model with the processed batch, handling unpacking if necessary.
This method supports passing the batch to the model as unpacked dictionaries or lists, based on the predictor arguments, or directly if no unpacking is required.
- Parameters:
processed_batch – The batch of data to pass to the model, already preprocessed.
- Returns:
The raw forward pass output from the model.
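The unpacking behavior can be sketched in plain Python. The unpack_dicts and unpack_lists flag names below are assumptions for illustration (the docs only say unpacking is controlled by the predictor arguments), and the model is a stand-in callable:

```python
def call_with_unpacking(model, processed_batch, unpack_dicts=False, unpack_lists=False):
    """Sketch of the documented unpacking behavior (flag names assumed)."""
    if unpack_dicts and isinstance(processed_batch, dict):
        return model(**processed_batch)  # keys become keyword arguments
    if unpack_lists and isinstance(processed_batch, (list, tuple)):
        return model(*processed_batch)   # items become positional arguments
    return model(processed_batch)        # otherwise, pass the batch directly

def fake_model(x, y=0):
    # Stand-in for a torch.nn.Module forward pass.
    return x + y

print(call_with_unpacking(fake_model, {"x": 1, "y": 2}, unpack_dicts=True))  # 3
print(call_with_unpacking(fake_model, [1, 2], unpack_lists=True))            # 3
print(call_with_unpacking(fake_model, 5))                                    # 5
```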