Map Functions
In most training scripts, some kind of transform is applied to the data before it is consumed by the model. For example, in computer vision, augmentations and conversion to tensors are usually applied. 3LC supports this through the method table.map().
When using table.map(), provide a Python function that takes a full sample and returns the transformed sample. The transform is applied on-the-fly whenever __getitem__ is called on the Table, essentially overriding the Sample View of the current Table instance. The transformed samples are not saved anywhere: there is no in-memory caching, and the Table data itself is not modified in any way.
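To make the on-the-fly behavior concrete, here is a minimal pure-Python sketch. `ToyTable` and `jitter` are hypothetical names written for illustration; this is not the 3LC `Table` implementation. It only mimics the behavior described above: the map function runs on every `__getitem__` call, nothing is cached, and the stored rows are left untouched.

```python
import random


class ToyTable:
    """Hypothetical stand-in for a table; NOT the 3LC implementation."""

    def __init__(self, rows):
        self._rows = list(rows)  # stored data, never modified
        self._map_fn = None      # optional per-sample transform

    def map(self, fn):
        self._map_fn = fn
        return self

    def __getitem__(self, index):
        sample = self._rows[index]
        # The transform runs on every access; the result is not cached.
        return sample if self._map_fn is None else self._map_fn(sample)


calls = []  # records every invocation of the map function


def jitter(sample):
    calls.append(sample["value"])
    # Build a new dict instead of mutating the stored sample.
    return {"value": sample["value"] + random.random()}


toy = ToyTable([{"value": 1.0}]).map(jitter)
first, second = toy[0], toy[0]  # two accesses, two independent transforms
```

Because nothing is cached, a random transform can yield a different result on each access to the same index, which is why repeated reads of `table[0]` are enough to preview augmentations.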
Example
The following simple example shows how table.map() can be used to apply torchvision transforms to the data.
[1]:
import tlc
from PIL import Image

table = tlc.Table.from_torch_dataset(
    dataset=[{"image": Image.open("3lc-logo.png")}],
    structure={"image": tlc.PILImage},
    project_name="Map Functions Project",
    if_exists="overwrite",
)

# At first, the `Sample` view presents the PIL Image
table[0]["image"]
[2]:
from torchvision import transforms

# Random augmentations, re-sampled every time a sample is accessed
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(degrees=45),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.5, 1.5), shear=10),
    transforms.ToPILImage(),  # Back to PIL for visualization
])

def train_map_function(sample):
    return {"image": train_transforms(sample["image"])}

table.map(train_map_function)

def plot_augmented_samples(images):
    # Paste the images side by side into a single wide image
    widths, heights = zip(*(im.size for im in images))
    total_width = sum(widths)
    max_height = max(heights)
    concat_img = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for im in images:
        concat_img.paste(im, (x_offset, 0))
        x_offset += im.size[0]
    return concat_img

# Each access to table[0] applies the transforms again
plot_augmented_samples([table[0]["image"] for _ in range(10)])
Collection-only map functions
It is common to perform different augmentations for training and validation. In 3LC, this corresponds to applying different transforms during metrics collection from those used in training.
Map functions specific to metrics collection can be added with the method table.map_collect_metrics(). These transforms are applied when using tlc.collect_metrics(). If none are defined, the functions provided through table.map() are applied in metrics collection instead.
If you would like to use this distinction in your own code, use the Table.collection_mode() context manager:
[3]:
collection_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.ToPILImage(),  # Back to PIL for visualization
])

def collection_map_function(sample):
    return {"image": collection_transforms(sample["image"])}

table = table.map_collect_metrics(collection_map_function)

with table.collection_mode():
    collection_plot = plot_augmented_samples([table[0]["image"] for _ in range(10)])

collection_plot
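The fallback rule described in this section can be sketched in plain Python. `ToyCollectTable` is a hypothetical class written for illustration, not how 3LC implements it; the only behavior it models is the one stated above: inside collection mode the collection-specific function is used if one was registered, and the regular map function is used otherwise.

```python
from contextlib import contextmanager


class ToyCollectTable:
    """Hypothetical stand-in; NOT the 3LC implementation."""

    def __init__(self, rows):
        self._rows = list(rows)
        self._map_fn = None       # set by map()
        self._collect_fn = None   # set by map_collect_metrics()
        self._collecting = False  # True only inside collection_mode()

    def map(self, fn):
        self._map_fn = fn
        return self

    def map_collect_metrics(self, fn):
        self._collect_fn = fn
        return self

    @contextmanager
    def collection_mode(self):
        previous = self._collecting
        self._collecting = True
        try:
            yield self
        finally:
            self._collecting = previous  # restore even if the body raises

    def __getitem__(self, index):
        fn = self._map_fn
        # Prefer the collection function, but only if one was registered.
        if self._collecting and self._collect_fn is not None:
            fn = self._collect_fn
        sample = self._rows[index]
        return sample if fn is None else fn(sample)


toy = ToyCollectTable([{"value": 2}]).map(lambda s: {"value": s["value"] * 10})

with toy.collection_mode():
    fallback = toy[0]    # no collection function yet: map() is used

toy.map_collect_metrics(lambda s: {"value": -s["value"]})

with toy.collection_mode():
    collecting = toy[0]  # the collection function takes over
training = toy[0]        # outside the context: map() is used again
```

Restoring the previous flag in a `finally` block is a design choice of this sketch so that an exception inside the `with` body cannot leave the toy table stuck in collection mode.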
Multiple map functions
Multiple map functions can be applied in succession by repeatedly calling table.map().
[4]:
table = table.map(lambda x: x)
Note that it is not possible to add the same function multiple times, and an attempt to do so will emit a warning:
[5]:
table = table.map(train_map_function)
WARNING 3lc: Function <function train_map_function at 0x16cbe3920> already exists in the map functions list. Skipping.
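As a rough illustration of chaining and duplicate handling, the sketch below keeps a list of functions and applies them one after another. `ToyChainTable`, `add_one`, and `double` are hypothetical names, and applying the functions in registration order is an assumption of this sketch rather than something stated above; the warning text simply echoes the message shown in the output.

```python
import warnings


class ToyChainTable:
    """Hypothetical stand-in; NOT the 3LC implementation."""

    def __init__(self, rows):
        self._rows = list(rows)
        self._map_fns = []  # applied in registration order (assumption)

    def map(self, fn):
        if fn in self._map_fns:
            # The same function object is registered at most once.
            warnings.warn(
                f"Function {fn!r} already exists in the map functions list. Skipping."
            )
        else:
            self._map_fns.append(fn)
        return self

    def __getitem__(self, index):
        sample = self._rows[index]
        for fn in self._map_fns:
            sample = fn(sample)  # each function receives the previous output
        return sample


def add_one(sample):
    return {"value": sample["value"] + 1}


def double(sample):
    return {"value": sample["value"] * 2}


toy = ToyChainTable([{"value": 3}]).map(add_one).map(double)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    toy.map(add_one)  # duplicate: skipped, with a warning

chained = toy[0]      # (3 + 1) * 2
```

Note that the duplicate check in this sketch compares function objects, so two separately written lambdas with identical bodies would both be registered.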
Removing map functions
To remove all the map functions from a Table instance, use the method table.clear_maps():
[6]:
table.clear_maps()
table[0]["image"]