tlc.client.torch.metrics.metrics_collectors.classification_metrics_collector

Collect loss and predictions for classification problems.

Module Contents

Classes

Class                             Description
ClassificationMetricsCollector    Collect common metrics for classification tasks.

API

class tlc.client.torch.metrics.metrics_collectors.classification_metrics_collector.ClassificationMetricsCollector(classes: list[str] | None = None, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.CrossEntropyLoss(reduction='none'), compute_aggregates: bool = True, preprocess_fn: Callable[[tlc.core.builtins.types.SampleData, tlc.client.torch.metrics.predictor.PredictorOutput], tuple[tlc.core.builtins.types.SampleData, torch.Tensor]] | None = None)

Bases: tlc.client.torch.metrics.metrics_collectors.metrics_collector_base.MetricsCollector

Collect common metrics for classification tasks.

This class is a specialized version of MetricsCollector designed to collect metrics relevant to classification problems. It assumes that the forward pass of the model returns raw (unnormalized) logits for each class. The following per-sample metrics are collected:

  • loss: The per-sample loss value, computed with the function passed as loss_fn. By default, this is the cross-entropy loss.

  • predicted: The predicted class label.

  • accuracy: The per-sample accuracy of the prediction, i.e. whether the predicted label matches the ground-truth label.

  • confidence: The confidence of the prediction.
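As a rough illustration of what these per-sample values mean, the sketch below computes them in plain Python for a batch of raw logits and integer labels. It assumes the default loss (cross-entropy, which for class-index targets equals the negative log-softmax of the true class) and assumes that confidence is the softmax probability of the predicted class; the collector's actual definition of confidence may differ. The helper names log_softmax and per_sample_metrics are hypothetical and not part of the tlc API.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def per_sample_metrics(logits: list[list[float]], labels: list[int]) -> list[dict]:
    """Hypothetical helper: per-sample loss, predicted, accuracy, confidence.

    Assumes raw per-class logits and integer class labels. 'confidence' is
    ASSUMED to be the softmax probability of the predicted class.
    """
    rows = []
    for sample_logits, label in zip(logits, labels):
        log_probs = log_softmax(sample_logits)
        probs = [math.exp(lp) for lp in log_probs]
        # Predicted class = index of the largest probability (argmax).
        predicted = max(range(len(probs)), key=probs.__getitem__)
        rows.append({
            "loss": -log_probs[label],  # unreduced cross-entropy
            "predicted": predicted,
            "accuracy": float(predicted == label),  # 1.0 if correct, else 0.0
            "confidence": probs[predicted],
        })
    return rows


metrics = per_sample_metrics([[2.0, 0.5, -1.0], [0.0, 0.0, 3.0]], [0, 1])
```

In PyTorch, torch.nn.CrossEntropyLoss(reduction="none") applied to an (N, C) logits tensor and an (N,) tensor of class indices returns the same N loss values as the "loss" entries above (with no class weights or label smoothing).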

Example:

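import tlc
from tlc.client.torch.metrics.metrics_collectors.classification_metrics_collector import ClassificationMetricsCollector

table = ...  # the tlc.Table to collect metrics for
model = ...  # a torch.nn.Module whose forward pass returns raw class logits
collector = ClassificationMetricsCollector()

tlc.collect_metrics(table, collector, model)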

Initialize the classification metrics collector.

Parameters:
  • classes – List of class names. If provided, the schema of the predicted column is updated to include a value map from class indices to class names. Default is None.

  • loss_fn – Unreduced (per-sample) loss function to use for calculating the loss metric. Default is torch.nn.CrossEntropyLoss(reduction="none").

  • compute_aggregates – Whether to compute aggregate metrics. Default is True.

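  • preprocess_fn – Optional function that takes the batch and the predictor output and returns a tuple of the (possibly transformed) batch and a torch.Tensor, matching the signature of preprocess. Default is None.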
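The loss_fn parameter must return one loss value per sample rather than a single scalar, because the collector reports the loss per sample. The plain-Python sketch below contrasts the two reductions for cross-entropy on raw logits; in PyTorch the same difference is controlled by the reduction argument of torch.nn.CrossEntropyLoss ("none" versus the default "mean"). The helper name cross_entropy is hypothetical and only illustrates the shape of the result.

```python
import math


def cross_entropy(logits, labels, reduction="none"):
    """Hypothetical cross-entropy for raw logits and integer labels."""
    losses = []
    for row, label in zip(logits, labels):
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_z - row[label])
    if reduction == "none":
        return losses  # one value per sample: the shape loss_fn must return
    if reduction == "mean":
        return sum(losses) / len(losses)  # single scalar: unsuitable for loss_fn
    raise ValueError(f"unsupported reduction: {reduction!r}")


logits = [[1.0, 2.0], [3.0, 0.0], [0.5, 0.5]]
labels = [1, 0, 1]
per_sample = cross_entropy(logits, labels, reduction="none")
batch_mean = cross_entropy(logits, labels, reduction="mean")
```

With PyTorch, a custom per-sample loss could be passed as, for example, ClassificationMetricsCollector(loss_fn=torch.nn.CrossEntropyLoss(reduction="none", label_smoothing=0.1)); keeping reduction="none" is what preserves one value per sample.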

compute_metrics(batch: tlc.core.builtins.types.SampleData, predictor_output: tlc.client.torch.metrics.predictor.PredictorOutput) → dict[str, tlc.core.builtins.types.MetricData]

property column_schemas: dict[str, tlc.core.schema.Schema]

preprocess(batch: tlc.core.builtins.types.SampleData, predictor_output: tlc.client.torch.metrics.predictor.PredictorOutput) → tuple[tlc.core.builtins.types.SampleData, torch.Tensor]

is_one_hot(labels: torch.Tensor) → bool

Check if labels are one-hot encoded.
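One plausible way to perform such a check, sketched here in plain Python on nested lists rather than on a torch.Tensor, is to require a 2-D array in which every entry is 0 or 1 and every row sums to exactly 1. This criterion is an assumption; the library's is_one_hot may apply different rules. Once labels are known to be one-hot, the class index of each sample is the position of the 1 in its row. The helper names looks_one_hot and one_hot_to_indices are hypothetical.

```python
def looks_one_hot(labels: list[list[float]]) -> bool:
    """Hypothetical check: 2-D, entries in {0, 1}, exactly one 1 per row."""
    # A flat list of class indices (or an empty input) is not one-hot.
    if not labels or not all(isinstance(row, list) and row for row in labels):
        return False
    width = len(labels[0])
    for row in labels:
        if len(row) != width:
            return False  # ragged rows cannot form a one-hot matrix
        if any(v not in (0, 1) for v in row) or sum(row) != 1:
            return False
    return True


def one_hot_to_indices(labels: list[list[float]]) -> list[int]:
    # The position of the single 1 in each row is the class index.
    return [row.index(1) for row in labels]


one_hot = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
```

For tensors, the equivalent conversion in PyTorch is labels.argmax(dim=1).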