tlc.integration.hugging_face.trainer

Module Contents

Classes

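| Class | Description |
|-------|-------------|
| `Trainer` | Subclass of `transformers.Trainer` that accepts `tlc.core.objects.table.Table` objects as its training and evaluation datasets. |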

API

```python
class tlc.integration.hugging_face.trainer.Trainer(
    model: transformers.modeling_utils.PreTrainedModel | torch.nn.Module | None = None,
    args: transformers.training_args.TrainingArguments | None = None,
    data_collator: transformers.data.data_collator.DataCollator | None = None,
    train_dataset: tlc.core.objects.table.Table | None = None,
    eval_dataset: tlc.core.objects.table.Table | dict[str, tlc.core.objects.table.Table] | None = None,
    processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.image_processing_utils.BaseImageProcessor | transformers.feature_extraction_utils.FeatureExtractionMixin | transformers.processing_utils.ProcessorMixin | None = None,
    model_init: Callable[[], transformers.modeling_utils.PreTrainedModel] | None = None,
    compute_loss_func: Callable | None = None,
    compute_metrics: Callable[[transformers.modeling_outputs.EvalPrediction], dict] | None = None,
    callbacks: list[transformers.trainer.TrainerCallback] | None = None,
    optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
    preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    *,
    run_name: str | None = None,
    run_description: str | None = None,
    exclude_zero_weights_metrics_collection: bool = False,
    exclude_zero_weights_train: bool = True,
    weighted: bool = True,
    shuffle: bool = True,
    repeat_by_weight: bool = False,
    metrics_collectors: tlc.client.torch.metrics.metrics_collectors.MetricsCollectorType | None = None,
    metrics_collection_epochs: list[int] | None = None,
)
```

Bases: transformers.Trainer