tlc.integration.hugging_face.trainer

Module Contents

Classes

| Class | Description |
|-------|-------------|
| `Trainer` | |

API

class Trainer(
model: transformers.modeling_utils.PreTrainedModel | Module | None = None,
args: transformers.training_args.TrainingArguments | None = None,
data_collator: transformers.data.data_collator.DataCollator | None = None,
train_dataset: Table | None = None,
eval_dataset: Table | dict[str, Table] | None = None,
processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.image_processing_utils.BaseImageProcessor | transformers.feature_extraction_utils.FeatureExtractionMixin | transformers.processing_utils.ProcessorMixin | None = None,
model_init: Callable[[], transformers.modeling_utils.PreTrainedModel] | None = None,
compute_loss_func: Callable | None = None,
compute_metrics: Callable[[transformers.modeling_outputs.EvalPrediction], dict] | None = None,
callbacks: list[transformers.trainer.TrainerCallback] | None = None,
optimizers: tuple[Optimizer | None, LambdaLR | None] = (None, None),
preprocess_logits_for_metrics: Callable[[Tensor, Tensor], Tensor] | None = None,
*,
run_name: str | None = None,
run_description: str | None = None,
exclude_zero_weights_metrics_collection: bool = False,
exclude_zero_weights_train: bool = True,
weighted: bool = True,
shuffle: bool = True,
repeat_by_weight: bool = False,
metrics_collectors: tlc.client.torch.metrics.metrics_collectors.metrics_collector_base.MetricsCollectorType | None = None,
metrics_collection_epochs: list[int] | None = None,
)

Bases: transformers.Trainer