tlc.integration.hugging_face.trainer

Module Contents

Classes

| Class | Description |
| --- | --- |
| `Trainer` | |

API

class Trainer(
    model: PreTrainedModel | Module | None = None,
    args: TrainingArguments | None = None,
    data_collator: Callable | None = None,
    train_dataset: Table | None = None,
    eval_dataset: Table | dict[str, Table] | None = None,
    processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin | None = None,
    model_init: Callable[[], PreTrainedModel] | None = None,
    compute_loss_func: Callable | None = None,
    compute_metrics: Callable[[EvalPrediction], dict] | None = None,
    callbacks: list[TrainerCallback] | None = None,
    optimizers: tuple[Optimizer | None, LambdaLR | None] = (None, None),
    preprocess_logits_for_metrics: Callable[[Tensor, Tensor], Tensor] | None = None,
    *,
    run_name: str | None = None,
    run_description: str | None = None,
    exclude_zero_weights_metrics_collection: bool = False,
    exclude_zero_weights_train: bool = True,
    weighted: bool = True,
    shuffle: bool = True,
    repeat_by_weight: bool = False,
    metrics_collectors: tlc.client.torch.metrics.metrics_collectors.metrics_collector_base.MetricsCollectorType | None = None,
    metrics_collection_epochs: list[int] | None = None,
)

Bases: transformers.Trainer