Integrating 3LC with PyTorch Lightning

PyTorch Lightning is a popular machine learning framework that provides a high-level interface for training and evaluating PyTorch models. To integrate 3LC with Lightning, use the primitives provided by the 3lc Python package within Lightning’s standard patterns.

Integration pattern

Create your tables eagerly, before Trainer.fit(), then add 3LC calls to your LightningModule. The key steps are creating the tables, creating the run, and collecting metrics.

import tlc
import pytorch_lightning as pl
from torch.utils.data import DataLoader

train_table = tlc.Table.from_...
val_table = tlc.Table.from_...

class MyModule(pl.LightningModule):
    def __init__(self, train_table, val_table):
        super().__init__()
        # Tables are not pickleable in the way Lightning expects for hparams, so exclude them.
        self.save_hyperparameters(ignore=["train_table", "val_table"])
        self.train_table = train_table
        self.val_table = val_table
        # ... model definition ...

    def train_dataloader(self):
        return DataLoader(
            self.train_table,
            sampler=tlc.integration.torch.samplers.create_sampler(
                self.train_table, weighted=True, exclude_zero_weights=True
            ),
            batch_size=...,
            num_workers=...,
        )

    def val_dataloader(self):
        return DataLoader(self.val_table, batch_size=..., num_workers=...)

    def on_train_start(self):
        super().on_train_start()
        self.tlc_run = tlc.init(
            project_name="my-project",
            run_name=...,
            description=...,
            parameters=self.hparams_initial,
        )
        self.tlc_run.set_status_running()

    def on_train_epoch_end(self):
        super().on_train_epoch_end()
        # Collect metrics every fifth epoch (current_epoch is zero-indexed).
        if (self.current_epoch + 1) % 5 == 0:
            self._collect_3lc_metrics()

    def on_train_end(self):
        super().on_train_end()
        self._collect_3lc_metrics()
        self.tlc_run.set_status_completed()

    def _collect_3lc_metrics(self):
        predictor = tlc.metrics.Predictor(self)
        for split, table in [("train", self.train_table), ("val", self.val_table)]:
            tlc.collect_metrics(
                table=table,
                metrics_collectors=[...],
                predictor=predictor,
                split=split,
                constants={"epoch": self.current_epoch},
                exclude_zero_weights=True,
            )


pl.Trainer(...).fit(MyModule(train_table, val_table))

Note

Lightning’s self.log(...) writes to trainer.callback_metrics. Surfacing those onto the tlc.Run is left to user code today. See the example notebooks for one approach.
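
One possible first step, sketched below, is to flatten trainer.callback_metrics into plain Python floats inside a hook such as on_train_epoch_end. The helper name callback_metrics_to_floats is hypothetical and is not part of 3lc or Lightning. The step that attaches the values to the tlc.Run is left as a comment, since the appropriate call depends on your 3lc version.

```python
from typing import Any, Dict, Mapping


def callback_metrics_to_floats(metrics: Mapping[str, Any]) -> Dict[str, float]:
    """Convert logged values (scalar tensors or numbers) into plain floats.

    Hypothetical helper for illustration; not part of 3lc or Lightning.
    """
    flat: Dict[str, float] = {}
    for name, value in metrics.items():
        try:
            # Scalar tensors expose .item(); plain numbers do not.
            scalar = value.item() if hasattr(value, "item") else value
            flat[name] = float(scalar)
        except (TypeError, ValueError, RuntimeError):
            # Skip values that cannot be read as a single number.
            continue
    return flat


# Inside a LightningModule hook this might be used as:
#     values = callback_metrics_to_floats(self.trainer.callback_metrics)
#     ...attach `values` to self.tlc_run using the API of your 3lc version...
```

Converting to floats first keeps tensors, and any device memory they hold, out of whatever gets stored alongside the run.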

Note

Running under DDP calls for some extra care when integrating 3LC:

  • Lightning replicates the LightningModule across ranks, so hooks like on_train_end fire on every process. Guard tlc.collect_metrics calls behind self.trainer.is_global_zero (or @rank_zero_only) to avoid every rank writing the same metrics table.

  • Lightning auto-wraps your sampler in a DistributedSampler when DDP is active. Consider whether that interacts cleanly with tlc.integration.torch.samplers.create_sampler’s weighting, or whether you want to pass use_distributed_sampler=False to the Trainer so that Lightning leaves your sampler untouched.

  • Creating tables outside the module, before Trainer.fit(), keeps table construction off the worker ranks entirely.
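
The rank guard from the first point, together with the every-fifth-epoch cadence used in the module above, can be folded into one small predicate. This is a minimal sketch: should_collect_metrics and its parameters are hypothetical names, and in a real module the is_global_zero argument would come from self.trainer.is_global_zero.

```python
def should_collect_metrics(
    current_epoch: int,
    is_global_zero: bool,
    every_n_epochs: int = 5,
) -> bool:
    """Return True when this process should collect 3LC metrics this epoch.

    Hypothetical helper: only the global-zero rank collects, and only on
    every `every_n_epochs`-th epoch (epochs are zero-indexed).
    """
    if every_n_epochs < 1:
        raise ValueError("every_n_epochs must be >= 1")
    if not is_global_zero:
        # Other ranks skip collection so the metrics table is written once.
        return False
    return (current_epoch + 1) % every_n_epochs == 0


# In on_train_epoch_end this might read:
#     if should_collect_metrics(self.current_epoch, self.trainer.is_global_zero):
#         self._collect_3lc_metrics()
```

The final collection in on_train_end needs the same rank check, without the cadence condition.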

Example notebooks

See the end-to-end example notebooks for full integration examples.

Note

The @tlc.integration.pytorch_lightning.lightning_module decorator was removed in 3lc version 3.0. The patterns above cover everything the decorator did, with the 3LC-side intent visible at the call site. If you have a 2.x project that depends on the decorator and the migration above is not enough, please reach out to the 3LC team.