tlc.client.torch.metrics.metrics_collectors.classification_metrics_collector#

Collect loss and predictions for classification problems.

Module Contents#

Classes#

Class

Description

ClassificationMetricsCollector

Collects classification metrics: loss and predictions.

API#

class tlc.client.torch.metrics.metrics_collectors.classification_metrics_collector.ClassificationMetricsCollector(model: torch.nn.Module, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.CrossEntropyLoss(reduction='none'), predicted_schema: tlc.core.schema.Schema | None = None, compute_aggregates: bool = True)#

Bases: tlc.client.torch.metrics.metrics_collectors.metrics_collector_base.MetricsCollector

Collects classification metrics: loss and predictions.

This class is a specialized version of MetricsCollector and is designed to collect metrics relevant to classification problems.

You can set up data transformation pipelines using the transforms, transform, and target_transform parameters; a usage sketch follows the parameter list below.

Example:

model = SomeTorchModel()
collector = ClassificationMetricsCollector(model)

Parameters:
  • model – The PyTorch model for which the metrics are to be collected.

  • criterion – Unreduced (per-sample) loss function to use for calculating the loss metric. Default is torch.nn.CrossEntropyLoss(reduction="none").

  • predicted_schema – Schema for the predicted output. Can be None.

  • transforms – A callable for common transforms to both input and target. Optional.

  • transform – A callable for transforming the input data before prediction. Keyword-only.

  • target_transform – A callable for transforming the target data before loss computation. Keyword-only.
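
A hedged usage sketch of the constructor parameters above. The model, tensor shapes, and transform functions are illustrative assumptions, not part of the tlc API:

import torch

from tlc.client.torch.metrics.metrics_collectors.classification_metrics_collector import (
    ClassificationMetricsCollector,
)

# Placeholder 10-class model; any torch.nn.Module producing logits would do.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))

collector = ClassificationMetricsCollector(
    model,
    # Per-sample (unreduced) loss, matching the documented default.
    criterion=torch.nn.CrossEntropyLoss(reduction="none"),
    # Keyword-only transform hooks documented above (assumed input dtypes):
    transform=lambda x: x.float() / 255.0,  # scale raw uint8 images before prediction
    target_transform=lambda y: y.long(),    # ensure integer class indices for the loss
)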

compute_metrics(batch: tlc.core.builtins.types.SampleData, predictions: torch.Tensor, _: Any) → dict[str, tlc.core.builtins.types.MetricData]#
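
A hedged sketch of calling compute_metrics directly, continuing the constructor example above; in normal use the tlc metrics-collection loop supplies the batch and the predictions. The (inputs, targets) batch layout and the returned column names are assumptions:

# Fabricated batch of 8 samples with integer class labels.
inputs = torch.rand(8, 1, 28, 28)
targets = torch.randint(0, 10, (8,))

with torch.no_grad():
    predictions = model(inputs)  # logits of shape (8, 10)

metrics = collector.compute_metrics((inputs, targets), predictions, None)
# Assumed result: a dict of per-sample metric columns (e.g. loss and
# predicted class), each holding 8 values.
for name, values in metrics.items():
    print(name, len(values))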
property column_schemas: dict[str, tlc.core.schema.Schema]#
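
The extra column schemas the collector contributes can be inspected directly; the assumption here is that the mapping is keyed by metric column name:

# Inspect the schemas of the metric columns this collector produces
# (assumed keys: metric column names such as "loss" and "predicted").
for column_name, schema in collector.column_schemas.items():
    print(column_name, type(schema).__name__)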