Implement Custom MTL Dataset
This section guides you through implementing custom multi-task learning datasets for use in CLArena.
A multi-task learning dataset can be constructed from single-task datasets by combining several of them, with each dataset treated as a separate task.
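The combining idea can be illustrated with a small, self-contained sketch. It is not CLArena's implementation: the class name CombinedTasksSketch and the 1-based task IDs are assumptions chosen for illustration. It concatenates several single-task datasets into one indexable dataset and maps each global index back to its source dataset, which then serves as that sample's task. Any object supporting len() and integer indexing that returns (input, target) pairs works, including a PyTorch map-style Dataset.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Any, Sequence, Tuple


class CombinedTasksSketch:
    """Hypothetical sketch: concatenate single-task datasets, one task per dataset.

    Each dataset must support len() and integer indexing returning (input, target).
    Task IDs are assumed to start at 1 and follow the order of `datasets`.
    """

    def __init__(self, datasets: Sequence[Any]) -> None:
        self.datasets = list(datasets)
        # Running totals of dataset sizes, e.g. lengths [3, 2] -> [3, 5].
        self.cumulative_sizes = list(accumulate(len(d) for d in self.datasets))

    def __len__(self) -> int:
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
        if not 0 <= index < len(self):
            raise IndexError(index)
        # Locate the dataset whose index range contains `index`.
        dataset_idx = bisect_right(self.cumulative_sizes, index)
        start = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
        x, y = self.datasets[dataset_idx][index - start]
        return x, y, dataset_idx + 1  # hypothetical 1-based task ID


# Usage with plain lists standing in for real datasets.
digits = [("img0", 0), ("img1", 1), ("img2", 2)]
letters = [("imgA", "a"), ("imgB", "b")]
combined = CombinedTasksSketch([digits, letters])
print(len(combined))  # 5
print(combined[3])    # ('imgA', 'a', 2)
```

Each task keeps its own target space here rather than remapping labels into one global label space. PyTorch's torch.utils.data.ConcatDataset performs the same cumulative-size index lookup but does not attach a task ID to each sample.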
Base Classes
In CLArena, multi-task learning datasets are implemented as subclasses of the base classes defined in clarena/mtl_datasets/base.py. These base classes inherit from Lightning's LightningDataModule and add features for multi-task learning:
- clarena.mtl_datasets.MTLDataset: the base class for all multi-task learning datasets.
- clarena.mtl_datasets.MTLCombinedDataset: the base class for combined multi-task learning datasets. A child class of MTLDataset.
- clarena.mtl_datasets.MTLDatasetFromCL: the base class for constructing multi-task learning datasets from continual learning datasets. A child class of MTLDataset.
Implement Combined MTL Dataset
The combined MTL dataset is already implemented as the Combined class in clarena/mtl_datasets/combined.py. To make more single-task datasets available for constructing a combined MTL dataset, add them to AVAILABLE_DATASETS and handle them in the prepare_data(), train_and_val_dataset(), and test_dataset() methods.
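Adding a dataset touches several places because each of these methods has to know how to handle it. The sketch below illustrates that registry-plus-dispatch pattern under stated assumptions. It does not reproduce the actual Combined class: the class name CombinedSketch, the toy dataset names "toy_a" and "toy_b", the val_fraction and seed parameters, the 1-based task_id argument, and the method signatures are all hypothetical.

```python
import random
from typing import Any, List, Tuple

Sample = Tuple[Any, Any]


class CombinedSketch:
    """Hypothetical sketch of where a new single-task dataset must be added.

    Not CLArena's Combined class; names, signatures and toy data are assumptions.
    """

    # (1) Register the new dataset name so it is accepted as a task.
    AVAILABLE_DATASETS = ["toy_a", "toy_b"]

    def __init__(self, dataset_names: List[str], val_fraction: float = 0.2, seed: int = 0) -> None:
        unsupported = [n for n in dataset_names if n not in self.AVAILABLE_DATASETS]
        if unsupported:
            raise ValueError(f"Unsupported datasets: {unsupported}")
        self.dataset_names = dataset_names  # task t uses dataset_names[t - 1]
        self.val_fraction = val_fraction
        self.seed = seed

    def prepare_data(self) -> None:
        # (2) Download each dataset's files here, branching on its name.
        # The in-memory toy data in this sketch needs no download.
        pass

    def train_and_val_dataset(self, task_id: int) -> Tuple[List[Sample], List[Sample]]:
        # (3) Build the task's training data, then hold out a validation split.
        name = self.dataset_names[task_id - 1]
        if name == "toy_a":
            data = [(f"a{i}", i % 2) for i in range(10)]
        elif name == "toy_b":
            data = [(f"b{i}", i % 3) for i in range(5)]
        else:
            raise ValueError(name)
        indices = list(range(len(data)))
        random.Random(self.seed).shuffle(indices)  # seeded for reproducible splits
        n_val = int(len(data) * self.val_fraction)
        train = [data[i] for i in indices[n_val:]]
        val = [data[i] for i in indices[:n_val]]
        return train, val

    def test_dataset(self, task_id: int) -> List[Sample]:
        # (4) Build the task's test data.
        name = self.dataset_names[task_id - 1]
        if name == "toy_a":
            return [(f"a{i}", i % 2) for i in range(4)]
        elif name == "toy_b":
            return [(f"b{i}", i % 3) for i in range(2)]
        raise ValueError(name)


mtl = CombinedSketch(["toy_a", "toy_b"])
mtl.prepare_data()
train, val = mtl.train_and_val_dataset(task_id=1)
print(len(train), len(val))  # 8 2
```

Forgetting any one of the four places typically surfaces as a rejected name or an unhandled branch at run time, which is why the text lists all of them.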
The MTL dataset must be task-labelled: each sample in a batch carries not only its input data and target label but also a task ID indicating which task it belongs to.
To turn a single-task dataset into a task-labelled dataset, you can use the TaskLabelledDataset wrapper in clarena/stl_datasets/base.py.
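The idea behind such a wrapper fits in a few lines. The class below, TaskLabelledSketch, is a hypothetical stand-in rather than the actual TaskLabelledDataset, whose constructor arguments may differ. It assumes the wrapped dataset returns (input, target) pairs and appends a fixed task ID to every sample.

```python
from typing import Any, Tuple


class TaskLabelledSketch:
    """Hypothetical wrapper that attaches a fixed task ID to every sample.

    Works with any map-style dataset (e.g. a torch.utils.data.Dataset)
    whose items are (input, target) pairs.
    """

    def __init__(self, dataset: Any, task_id: int) -> None:
        self.dataset = dataset
        self.task_id = task_id

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
        x, y = self.dataset[index]
        return x, y, self.task_id


# Usage with a plain list standing in for a real dataset.
wrapped = TaskLabelledSketch([("img0", 3), ("img1", 7)], task_id=2)
print(wrapped[1])  # ('img1', 7, 2)

# Grouping samples into a batch keeps the task IDs alongside inputs and targets.
inputs, targets, task_ids = zip(*(wrapped[i] for i in range(len(wrapped))))
print(task_ids)  # (2, 2)
```

When such a dataset is passed to a PyTorch DataLoader with its default collate function, each batch contains the inputs, the targets and the task IDs, which a multi-task model can use, for example, to select a task-specific output head.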
For more details, please refer to the API Reference and the source code. The MTL datasets already implemented in CLArena can serve as examples. Feel free to contribute by submitting pull requests on GitHub!