tlc.client.torch.metrics.collect_dataset
Collect per-sample metrics with a PyTorch Dataset.
Module Contents
Functions
| Function | Description |
|---|---|
| collect_metrics | Collect per-sample metrics with a PyTorch Dataset. |
API
- tlc.client.torch.metrics.collect_dataset.collect_metrics(table: torch.utils.data.Dataset, metrics_collectors: list[tlc.client.torch.metrics.metrics_collectors.metrics_collector_base.MetricsCollector], predictor: torch.nn.Module | tlc.client.torch.metrics.predictor.Predictor | None = None, constants: dict[str, Any] = {}, constants_schemas: dict[str, tlc.core.schema.Schema] | None = None, dataset_url: str = '', dataset_name: str = '', run_url: tlc.core.url.Url | str | None = None, collect_aggregates: bool = True, split: str = '', exclude_zero_weights: bool = False, *, dataloader_args: dict[str, Any] | None = None) -> None
Collect per-sample metrics with a PyTorch Dataset.
- Writes a single metrics table that uses the input table as its foreign table. This table contains any constants passed in the constants argument, as well as any metrics computed by the metrics collectors.
- Adds the metadata of the metrics table to the metrics property of the Run.
- Adds the Url of the input table to the Run as an input.
- Collects aggregate values from the metrics collectors and adds them to the Run, merged with the constants (when collect_aggregates is True).
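The shape of the per-sample metrics table described above can be pictured with a small, library-free sketch. It is not the 3LC implementation: it assumes collectors are plain callables returning a dict of metrics (not the real MetricsCollector class), that each row points back to its input sample through a hypothetical example_id column, and that constants are repeated on every row. The names build_metrics_rows and correct_prediction are hypothetical.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable
from typing import Any

# Hypothetical stand-in for a metrics collector: maps one sample to a dict of metrics.
Collector = Callable[[Any], dict[str, Any]]


def build_metrics_rows(
    samples: Iterable[Any],
    collectors: list[Collector],
    constants: dict[str, Any],
) -> list[dict[str, Any]]:
    """Build one metrics row per input sample (illustrative sketch only)."""
    rows = []
    for index, sample in enumerate(samples):
        # "example_id" is a hypothetical foreign-key column referencing the input table row.
        row: dict[str, Any] = {"example_id": index}
        # Assumption: constants are copied onto every row of the metrics table.
        row.update(constants)
        # Each collector contributes its own per-sample metric columns.
        for collector in collectors:
            row.update(collector(sample))
        rows.append(row)
    return rows


def correct_prediction(sample: tuple[float, int]) -> dict[str, int]:
    """Hypothetical collector: 1 if a thresholded score matches the label, else 0."""
    score, label = sample
    return {"correct": int((score >= 0.5) == bool(label))}


samples = [(0.9, 1), (0.2, 1), (0.7, 0), (0.1, 0)]  # (score, label) pairs
rows = build_metrics_rows(samples, [correct_prediction], {"epoch": 3})
print(rows[0])  # {'example_id': 0, 'epoch': 3, 'correct': 1}
```

Keeping a reference to the source row, rather than copying the sample itself, is what lets a metrics table stay small while still being joined back to the input table.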
- Parameters:
table – The Dataset to collect metrics from.
metrics_collectors – A list of metrics collectors to use.
predictor – A model or Predictor to use for computing metrics.
constants – A dictionary of constants to use when collecting metrics.
constants_schemas – A dictionary of schemas for the constants. If no schemas are provided, the schemas will be inferred from the constants.
dataset_url – The URL of the dataset.
dataset_name – The name of the dataset.
run_url – The URL of the run to add the metrics to. If not specified, the active run is used.
collect_aggregates – Whether to collect aggregate values from the metrics collectors and add them to the Run. This allows an aggregate view to be shown in the Project page of the 3LC Dashboard. Aggregate values are computed for all computable columns in the metrics collectors, and are prefixed with the split name. For example, if a metrics collector defines a computable column called “accuracy”, and the split is “train”, then the aggregate value will be called “train_accuracy_avg”.
split – The split of the dataset. This will be prepended to the aggregate metric names.
dataloader_args – Additional arguments to pass to the dataloader. This parameter is keyword-only.
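The aggregate naming described under collect_aggregates can be illustrated with a short sketch that reproduces the documented "train_accuracy_avg" pattern by averaging each column and merging the result with the constants. The helper aggregate_metrics is hypothetical; it assumes every column is averageable, that constants are merged without a prefix, and that an empty split adds no prefix. The real metrics collectors may compute different aggregates or handle these cases differently.

```python
from __future__ import annotations

from statistics import fmean
from typing import Any


def aggregate_metrics(
    per_sample: dict[str, list[float]],
    split: str,
    constants: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Average each column and name it "<split>_<column>_avg" (illustrative sketch only)."""
    # Assumption: an empty split means no prefix is added.
    prefix = f"{split}_" if split else ""
    # Start from the constants so the aggregates are merged with them.
    aggregates: dict[str, Any] = dict(constants or {})
    for column, values in per_sample.items():
        if values:  # skip empty columns instead of averaging nothing
            aggregates[f"{prefix}{column}_avg"] = fmean(values)
    return aggregates


summary = aggregate_metrics({"accuracy": [1, 0, 0, 1]}, split="train", constants={"epoch": 3})
print(summary)  # {'epoch': 3, 'train_accuracy_avg': 0.5}
```

Prefixing with the split name is what keeps, for example, train and validation averages from colliding when both are stored on the same Run.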