Implement Custom CL Dataset
This section guides you through implementing custom continual learning datasets for use in CLArena.
CL datasets are usually constructed from single-task datasets in three ways:
- Permute: Generate tasks from a single dataset by permuting the pixels of its images in a fixed order, using a different order for each task.
- Split: Split a single dataset by classes to create different tasks.
- Combine: Use different datasets, each as a separate task.
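The three construction methods can be illustrated with a toy, framework-free sketch. It is not how CLArena implements them. Samples here are plain lists of numbers standing in for images, and the helper names (permute_tasks, split_tasks, combine_tasks) are hypothetical.

```python
import random

# Toy "single-task dataset": 6 samples, each a flat list of 4 "pixels" and a label in {0, 1, 2}.
base = [([i, i + 1, i + 2, i + 3], i % 3) for i in range(6)]


def permute_tasks(dataset, num_tasks, seed=0):
    """Permute: every task reuses all samples, with one fixed pixel permutation per task."""
    rng = random.Random(seed)
    n_pixels = len(dataset[0][0])
    tasks = []
    for _ in range(num_tasks):
        perm = list(range(n_pixels))
        rng.shuffle(perm)  # the same permutation is applied to every sample of this task
        tasks.append([([x[j] for j in perm], y) for x, y in dataset])
    return tasks


def split_tasks(dataset, class_groups):
    """Split: each task keeps only the samples whose label is in its class group."""
    return [[(x, y) for x, y in dataset if y in group] for group in class_groups]


def combine_tasks(*datasets):
    """Combine: each separate dataset becomes its own task."""
    return [list(d) for d in datasets]


permuted = permute_tasks(base, num_tasks=3)
split = split_tasks(base, class_groups=[{0, 1}, {2}])
combined = combine_tasks(base, base[:2])

print(len(permuted), [len(t) for t in split], [len(t) for t in combined])  # 3 [4, 2] [6, 2]
```

Permuted tasks keep every label and every pixel value and change only pixel positions. Split tasks partition the label space. Combined tasks may differ in size and content entirely.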
Please refer to my continual learning beginners’ guide to learn more about CL datasets.
Base Classes
In CLArena, continual learning datasets are implemented as subclasses of the base classes defined in clarena/cl_datasets/base.py. These base classes inherit from the Lightning data module (LightningDataModule) and add features for continual learning:
- clarena.cl_datasets.CLDataset: The base class for all continual learning datasets.
- clarena.cl_datasets.CLPermutedDataset: The base class for permuted continual learning datasets. A child class of CLDataset.
- clarena.cl_datasets.CLSplitDataset: The base class for split continual learning datasets. A child class of CLDataset.
- clarena.cl_datasets.CLCombinedDataset: The base class for combined continual learning datasets. A child class of CLDataset.
Implement Permuted CL Dataset
To implement a permuted CL dataset:
- Inherit CLPermutedDataset.
- Define the class property original_dataset_python_class, which is the raw Python class of the original dataset that the permuted CL dataset is generated from. If there's no such class, implement one under clarena/stl_datasets/raw/ (preferably a PyTorch Dataset).
- Define the constants of the original dataset in a subclass of DatasetConstants in clarena/stl_datasets/raw/constants.py. Link the constants class to the original_dataset_python_class in the DATASET_CONSTANTS_MAPPING dictionary.
- Write prepare_data(), train_and_val_dataset() and test_dataset(). You may call the APIs provided by the original_dataset_python_class. Make sure to use self.train_and_val_transforms(), self.test_transforms() and self.target_transform() to assign transforms. Note that the permutation transform is already included in self.train_and_val_transforms() and self.test_transforms().
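The constants and permutation steps above can be pictured with the following sketch. Everything in it is a hypothetical stand-in. The real DatasetConstants and DATASET_CONSTANTS_MAPPING live in clarena/stl_datasets/raw/constants.py and may use different attribute names. The PixelPermutation class only shows what a per-task permutation transform does, because CLArena already includes one in self.train_and_val_transforms() and self.test_transforms(). Single-channel images are assumed.

```python
import random

# Hypothetical stand-ins for illustration only; the real DatasetConstants and
# DATASET_CONSTANTS_MAPPING may use different attribute names.
class DatasetConstants:
    """Base class holding fixed facts about an original dataset."""
    NUM_CLASSES: int
    IMG_SIZE: tuple  # (height, width); single-channel images assumed


class MyRawDataset:
    """Hypothetical original_dataset_python_class (in practice, a PyTorch Dataset)."""


class MyRawDatasetConstants(DatasetConstants):
    NUM_CLASSES = 10
    IMG_SIZE = (4, 4)


# Link the constants class to the raw dataset class.
DATASET_CONSTANTS_MAPPING = {MyRawDataset: MyRawDatasetConstants}


class PixelPermutation:
    """Apply one fixed pixel permutation to a flattened image, seeded for reproducibility."""

    def __init__(self, num_pixels, seed):
        self.perm = list(range(num_pixels))
        random.Random(seed).shuffle(self.perm)

    def __call__(self, flat_image):
        return [flat_image[j] for j in self.perm]


constants = DATASET_CONSTANTS_MAPPING[MyRawDataset]
height, width = constants.IMG_SIZE
num_pixels = height * width

# One permutation per task; task IDs start at 1, consistent with this guide's prepare_data() guard.
permutations = {task_id: PixelPermutation(num_pixels, seed=task_id) for task_id in (1, 2, 3)}

image = list(range(num_pixels))  # toy flattened 4x4 image
permuted_image = permutations[2](image)
print(sorted(permuted_image) == image)  # True: same pixels, new positions
```

Seeding each permutation by task ID keeps the permutation identical between the training and test transforms of the same task, which a permuted benchmark requires.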
prepare_data() is called at the beginning of each task. Include the following snippet to avoid redundant data downloads:
if self.task_id != 1:
    return  # download all original datasets only at the beginning of the first task
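The following sketch shows why the guard prevents repeated downloads. FakePermutedDataset is a hypothetical stand-in that mimics only the interaction between task_id and prepare_data(). The loop that advances task_id stands in for the continual learning training loop, and the counter stands in for an actual download call.

```python
class FakePermutedDataset:
    """Hypothetical stand-in mimicking only the task_id / prepare_data() interaction."""

    def __init__(self):
        self.task_id = 1
        self.download_count = 0

    def prepare_data(self):
        if self.task_id != 1:
            return  # download all original datasets only at the beginning of the first task
        self.download_count += 1  # stands in for downloading the original dataset


dm = FakePermutedDataset()
for task_id in range(1, 6):
    dm.task_id = task_id
    dm.prepare_data()  # called at the beginning of every task

print(dm.download_count)  # 1
```

All permuted tasks are derived from the same original dataset, so a single download at task 1 is enough for the whole task sequence.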
Implement Split CL Dataset
To implement a split CL dataset:
- Inherit CLSplitDataset.
- Define the class property original_dataset_python_class, which is the raw Python class of the original dataset that the split CL dataset is generated from. If there's no such class, implement one under clarena/stl_datasets/raw/ (preferably a PyTorch Dataset).
- Define the constants of the original dataset in a subclass of DatasetConstants in clarena/stl_datasets/raw/constants.py. Link the constants class to the original_dataset_python_class in the DATASET_CONSTANTS_MAPPING dictionary.
- Write get_subset_of_classes() to return the subset of classes for the current task, self.task_id.
- Write prepare_data(), train_and_val_dataset() and test_dataset(). You may call the APIs provided by the original_dataset_python_class. Make sure to use self.train_and_val_transforms(), self.test_transforms() and self.target_transform() to assign transforms. Use self.get_subset_of_classes() to filter the classes for the task.
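The class-subset and filtering steps can be sketched as plain functions. This assumes a contiguous split with a fixed number of classes per task. The real get_subset_of_classes() is a method that reads self.task_id, and its split scheme and signature may differ. The filter_by_classes helper and the toy samples are hypothetical.

```python
def get_subset_of_classes(task_id, classes_per_task):
    """Hypothetical contiguous class split: task 1 -> [0, k), task 2 -> [k, 2k), ..."""
    start = (task_id - 1) * classes_per_task
    return list(range(start, start + classes_per_task))


def filter_by_classes(samples, classes):
    """Keep only (x, y) samples whose label y belongs to the given classes."""
    keep = set(classes)
    return [(x, y) for x, y in samples if y in keep]


samples = [(f"img{i}", i % 10) for i in range(50)]  # toy 10-class dataset, 5 samples per class
task2_classes = get_subset_of_classes(task_id=2, classes_per_task=2)
task2_samples = filter_by_classes(samples, task2_classes)
print(task2_classes, len(task2_samples))  # [2, 3] 10
```

With a real PyTorch dataset, the same filtering would typically collect the matching indices and wrap them in torch.utils.data.Subset rather than building a new list.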
Implement Combined CL Dataset
The combined CL dataset is already implemented as Combined in clarena/cl_datasets/combined.py. To make more single-task datasets available for constructing combined CL datasets, add them to AVAILABLE_DATASETS and to the prepare_data(), train_and_val_dataset() and test_dataset() methods.
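To see why a new dataset must appear in several places, consider the following sketch, in which each task selects one single-task dataset by name. It is a hypothetical illustration and not the actual code of Combined. Here AVAILABLE_DATASETS is modeled as a dict of loader functions, while the real attribute may have a different type. The loaders return strings instead of datasets.

```python
# Hypothetical registry: maps each available single-task dataset name to a loader.
def load_mnist(split):
    return f"MNIST/{split}"  # stands in for building the real dataset split


def load_fashion_mnist(split):
    return f"FashionMNIST/{split}"


AVAILABLE_DATASETS = {
    "MNIST": load_mnist,
    "FashionMNIST": load_fashion_mnist,
}


def dataset_for_task(task_id, dataset_order, split):
    """Return the given split of the single-task dataset assigned to task_id (1-indexed)."""
    name = dataset_order[task_id - 1]
    if name not in AVAILABLE_DATASETS:
        raise ValueError(f"{name!r} is not in AVAILABLE_DATASETS")
    return AVAILABLE_DATASETS[name](split)


order = ["FashionMNIST", "MNIST"]
print(dataset_for_task(1, order, "train"))  # FashionMNIST/train
print(dataset_for_task(2, order, "test"))  # MNIST/test
```

A dataset that is missing from the registry fails early with a clear error instead of partway through training.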
For more details, please refer to the API Reference and source code. You can use the CL datasets already implemented in CLArena as examples. Feel free to contribute by submitting pull requests on GitHub!