Integrating 3LC with PyTorch Lightning
PyTorch Lightning is a popular machine learning framework that provides a
high-level interface for training and evaluating PyTorch models. You integrate 3LC with Lightning by calling the
primitives from the 3lc Python package inside Lightning's standard patterns, such as the LightningModule hooks.
Integration pattern
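Create your tables eagerly, before calling Trainer.fit(), then add 3LC calls to your LightningModule. The three key
steps are table creation, run creation, and metrics collection. In the example below, the run is created in
on_train_start, and metrics are collected every five epochs in on_train_epoch_end and once more in on_train_end.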
import tlc
import pytorch_lightning as pl
from torch.utils.data import DataLoader

# Create the tables eagerly, before Trainer.fit().
train_table = tlc.Table.from_...
val_table = tlc.Table.from_...


class MyModule(pl.LightningModule):
    def __init__(self, train_table, val_table):
        super().__init__()
        # Tables are not pickleable in the way Lightning expects for hparams, so exclude them.
        self.save_hyperparameters(ignore=["train_table", "val_table"])
        self.train_table = train_table
        self.val_table = val_table
        # ... model definition ...

    def train_dataloader(self):
        # Weighted sampler built from the table's sample weights; zero-weight rows are excluded.
        return DataLoader(
            self.train_table,
            sampler=tlc.integration.torch.samplers.create_sampler(
                self.train_table, weighted=True, exclude_zero_weights=True
            ),
            batch_size=...,
            num_workers=...,
        )

    def val_dataloader(self):
        return DataLoader(self.val_table, batch_size=..., num_workers=...)

    def on_train_start(self):
        super().on_train_start()
        # Create the 3LC run once training starts; the saved hyperparameters become run parameters.
        self.tlc_run = tlc.init(
            project_name="my-project",
            run_name=...,
            description=...,
            parameters=self.hparams_initial,
        )
        self.tlc_run.set_status_running()

    def on_train_epoch_end(self):
        super().on_train_epoch_end()
        # Collect 3LC metrics every 5 epochs.
        if (self.current_epoch + 1) % 5 == 0:
            self._collect_3lc_metrics()

    def on_train_end(self):
        super().on_train_end()
        # Collect a final round of metrics, then mark the run as completed.
        self._collect_3lc_metrics()
        self.tlc_run.set_status_completed()

    def _collect_3lc_metrics(self):
        # Run the model over both tables and write metrics tagged with the split and current epoch.
        predictor = tlc.metrics.Predictor(self)
        for split, table in [("train", self.train_table), ("val", self.val_table)]:
            tlc.collect_metrics(
                table=table,
                metrics_collectors=[...],
                predictor=predictor,
                split=split,
                constants={"epoch": self.current_epoch},
                exclude_zero_weights=True,
            )


pl.Trainer(...).fit(MyModule(train_table, val_table))
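The hooks above decide when metrics are collected. As a quick check of that schedule, the standalone sketch below lists the 0-based epoch indices whose weights get evaluated, assuming training runs for the full max_epochs with no early stopping. evaluated_epochs is a hypothetical helper, not part of 3LC or Lightning.

```python
def evaluated_epochs(max_epochs: int, every: int = 5) -> list[int]:
    """Hypothetical helper: which epochs' weights does the module above evaluate?

    Assumes training runs all `max_epochs` epochs (no early stopping).
    """
    # on_train_epoch_end collects when (current_epoch + 1) % every == 0.
    periodic = [epoch for epoch in range(max_epochs) if (epoch + 1) % every == 0]
    # on_train_end always collects once more, using the weights left by the final epoch.
    return periodic + [max_epochs - 1]


print(evaluated_epochs(12))  # [4, 9, 11]
print(evaluated_epochs(10))  # [4, 9, 9]: the final epoch's weights are evaluated twice
```

When max_epochs is a multiple of five, the final weights are collected twice, once by on_train_epoch_end and once by on_train_end. If that duplication matters, one option is to skip the periodic call on the last epoch, for example by comparing against self.trainer.max_epochs; keep in mind that early stopping can end training before max_epochs.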
Note
Values logged with Lightning’s self.log(...) are written to trainer.callback_metrics. Copying those values onto the
tlc.Run is currently left to user code; the example notebooks show one approach.
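A natural first step for that user code is to turn trainer.callback_metrics, whose values are typically scalar tensors, into a plain dict of Python floats. The sketch below only performs that conversion. to_plain_metrics is a hypothetical helper, and how the resulting dict is attached to the tlc.Run is deliberately left out, since that depends on which 3LC API you choose.

```python
from numbers import Real
from typing import Any, Mapping


def to_plain_metrics(callback_metrics: Mapping[str, Any]) -> dict[str, float]:
    """Hypothetical helper: convert logged metric values to plain floats.

    Scalar tensors expose .item(); plain real numbers pass through.
    Values that are not real numbers after conversion are skipped.
    """
    plain: dict[str, float] = {}
    for name, value in callback_metrics.items():
        if hasattr(value, "item"):
            value = value.item()  # e.g. a 0-dim torch.Tensor -> Python number
        if isinstance(value, Real):
            plain[name] = float(value)
    return plain


# Usage inside the LightningModule (not executed here):
#     metrics = to_plain_metrics(self.trainer.callback_metrics)
```

Converting to floats first keeps the hand-off to 3LC independent of tensor devices and dtypes.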
Note
If you are running under DDP, a few extra considerations apply when integrating 3LC:
- Lightning replicates the LightningModule across ranks, so hooks like on_train_end fire on every process. Guard
  tlc.collect_metrics calls behind self.trainer.is_global_zero (or @rank_zero_only) to avoid every rank writing the
  same metrics table.
- Lightning automatically wraps your sampler in a DistributedSampler when DDP is active. Consider whether that
  interacts cleanly with the weighting applied by tlc.integration.torch.samplers.create_sampler, or whether you want
  use_distributed_sampler=False.
- Creating tables outside the module, before Trainer.fit(), keeps table construction off the worker ranks entirely.
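To make the sampler point concrete, the sketch below models the two steps in plain Python. The first is a weighted draw with replacement that never selects zero-weight rows; this is an assumption about what create_sampler(weighted=True, exclude_zero_weights=True) does, not its actual implementation. The second shows how torch.utils.data.DistributedSampler with shuffle=False pads an index list and splits it across ranks. Lightning's wrapper may shuffle and seed differently, and both helper names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def weighted_draw(weights: Sequence[float], num_samples: int, seed: int = 0) -> List[int]:
    """Hypothetical stand-in for a weighted sampler: draw row indices with
    replacement, proportional to weight. Rows with weight 0 are excluded."""
    candidates = [i for i, w in enumerate(weights) if w > 0]
    if not candidates:
        raise ValueError("all weights are zero; nothing to sample")
    rng = random.Random(seed)
    return rng.choices(candidates, weights=[weights[i] for i in candidates], k=num_samples)


def shard_for_rank(indices: Sequence[int], rank: int, world_size: int) -> List[int]:
    """Simplified model of DistributedSampler(shuffle=False, drop_last=False):
    pad by repeating from the start until the length divides evenly by
    world_size, then give each rank every world_size-th index."""
    total = math.ceil(len(indices) / world_size) * world_size
    padded = list(indices)
    while len(padded) < total:
        padded.extend(indices[: total - len(padded)])
    return padded[rank:total:world_size]


draws = weighted_draw([0.0, 1.0, 2.0, 0.0, 3.0], num_samples=8)
for rank in range(3):
    print(rank, shard_for_rank(draws, rank, world_size=3))
```

Under this model, each rank trains on roughly 1/world_size of the weighted draws per epoch, and padding can repeat a few draws. If that is not what you want, setting use_distributed_sampler=False on the Trainer leaves your sampler unwrapped, and you are then responsible for any per-rank sharding yourself.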
Example notebooks
See the end-to-end example notebooks for complete integration examples.
Note
The @tlc.integration.pytorch_lightning.lightning_module decorator was removed in 3lc version 3.0. The patterns above
cover everything the decorator did, and they keep the intent of each 3LC call visible at its call site. If you have a
2.x project that depends on the decorator and the patterns above are not enough to migrate it, please reach out to the
3LC team.