Implement Your CL Dataset
Continual learning (CL) datasets are implemented as classes derived from the base classes provided by this package in clarena/cl_datasets/base.py. The base classes build on the Lightning data module and add features for continual learning:
- CLDataset: the general base class for all kinds of CL datasets, which incorporates mechanisms for managing sequential tasks.
- CLPermutedDataset: the base class for permuted CL datasets, based on CLDataset, which incorporates the permutation operation. See my CL beginners’ guide to learn more about permuted CL datasets.
Required Arguments
The required config fields for CLDataset correspond to the arguments of the base class:
- root: the root directory where the original data files used to construct the CL dataset are stored.
- num_tasks: the maximum number of tasks supported by the CL dataset.
- validation_percentage: the percentage of the training data to randomly split off as validation data. The validation dataset must be split from the training dataset.
- batch_size: the batch size used to construct the train, val and test dataloaders.
- num_workers: the number of workers for the dataloaders.
- custom_transforms: the custom transforms to apply to the TRAIN dataset ONLY. Do not include functional transforms such as ToTensor(), normalise, permute and so on, as they are already included in the system.
- custom_target_transforms: the custom target transforms to apply to dataset labels. Do not include the CL class mapping transform, as it is already included in the system.
The labels in your dataset are expected to be integers starting from 0. If they are not, you need to implement a target transform that maps your original class labels to integers.
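As an illustration of the note above, here is a minimal sketch of a target transform that maps arbitrary original labels to integers starting from 0. The class name LabelToIndex and the example labels are hypothetical and not part of this package; the sketch only assumes that a target transform is a callable applied to one label at a time.

```python
class LabelToIndex:
    """Hypothetical target transform: map original labels to integers 0..K-1."""

    def __init__(self, original_labels):
        # Sort the unique labels so the mapping is deterministic
        # regardless of the order in which labels are supplied.
        unique_labels = sorted(set(original_labels))
        self.mapping = {label: index for index, label in enumerate(unique_labels)}

    def __call__(self, label):
        # Raises KeyError for a label that was not seen at construction time.
        return self.mapping[label]


to_index = LabelToIndex(["cat", "dog", "bird"])
print(to_index("cat"))  # sorted order is bird, cat, dog, so this prints 1
```

A callable like this would presumably be supplied through custom_target_transforms; check the API documentation for the exact type that argument expects.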
Exclusive arguments for CLPermutedDataset:
- permutation_mode: the mode of permutation, one of:
  - ‘all’: permute all pixels.
  - ‘by_channel’: permute each channel separately. The same permutation order is applied to all channels.
  - ‘first_channel_only’: permute only the first channel.
- permutation_seeds: the seeds for the permutation operations used to construct tasks.
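To make the three permutation modes concrete, the following is a rough pure-Python sketch, not the package's implementation. It assumes an image is represented as a list of channels, each a flat list of pixel values, and the function name permute_image is hypothetical. A fixed seed yields a fixed permutation, which is how one seed can define one task.

```python
import random


def permute_image(image, mode, seed):
    """Permute pixels of `image` (list of channels, each a flat pixel list)."""
    rng = random.Random(seed)  # same seed -> same permutation
    num_channels = len(image)
    pixels_per_channel = len(image[0])

    if mode == "all":
        # One permutation over every pixel of every channel.
        flat = [pixel for channel in image for pixel in channel]
        order = list(range(len(flat)))
        rng.shuffle(order)
        shuffled = [flat[i] for i in order]
        return [
            shuffled[c * pixels_per_channel:(c + 1) * pixels_per_channel]
            for c in range(num_channels)
        ]

    # The remaining modes use one permutation over the pixels of a channel.
    order = list(range(pixels_per_channel))
    rng.shuffle(order)

    if mode == "by_channel":
        # Apply the same permutation order to each channel.
        return [[channel[i] for i in order] for channel in image]
    if mode == "first_channel_only":
        # Permute the first channel and copy the others unchanged.
        return [[image[0][i] for i in order]] + [list(ch) for ch in image[1:]]
    raise ValueError(f"unknown permutation mode: {mode!r}")


image = [[0, 1, 2, 3], [10, 11, 12, 13]]  # toy 2-channel, 4-pixel image
print(permute_image(image, "by_channel", seed=1))
```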
Implementation Instructions
These methods must be implemented:
- cl_class_map(): in this package, class labels are uniformly transformed to integers starting from 0, which makes it convenient to apply cross-entropy loss. You need to tell the module the mapping explicitly.
- mean(), std(): return the mean and standard deviation values used to construct the normalisation transform.
- prepare_data(): obtain your original dataset from whatever source it comes from, for example by downloading it. This is required by the parent Lightning data module.
- train_and_val_dataset(), test_dataset(): construct the training, validation and test datasets for the current task, indicated by self.task_id, which is updated automatically.
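The required methods above can be sketched as a skeleton. To keep the example runnable without the package, it uses a hypothetical stand-in base class, CLDatasetStandIn, rather than the real CLDataset. The method signatures, the return types, the 1-based task_id and the treatment of validation_percentage as a fraction are all assumptions made for illustration; consult the API documentation for the actual interface.

```python
from abc import ABC, abstractmethod
import random


class CLDatasetStandIn(ABC):
    """Hypothetical stand-in for the package's CLDataset, NOT the real class."""

    def __init__(self, num_tasks, validation_percentage):
        self.num_tasks = num_tasks
        self.validation_percentage = validation_percentage  # assumed fraction in [0, 1]
        self.task_id = 1  # assumed 1-based; the real base class updates it for you

    @abstractmethod
    def cl_class_map(self): ...

    @abstractmethod
    def mean(self): ...

    @abstractmethod
    def std(self): ...

    @abstractmethod
    def prepare_data(self): ...

    @abstractmethod
    def train_and_val_dataset(self): ...

    @abstractmethod
    def test_dataset(self): ...


class ToySplitDataset(CLDatasetStandIn):
    # Original class labels belonging to each task (toy values invented here).
    CLASSES_PER_TASK = {1: [3, 7], 2: [4, 9]}

    def _samples(self):
        # Five toy (input, original label) pairs per class of the current task.
        classes = self.CLASSES_PER_TASK[self.task_id]
        return [(float(i), label) for label in classes for i in range(5)]

    def cl_class_map(self):
        # Map the current task's original labels to integers starting from 0.
        classes = self.CLASSES_PER_TASK[self.task_id]
        return {original: index for index, original in enumerate(classes)}

    def mean(self):
        return (0.5,)  # placeholder per-channel mean

    def std(self):
        return (0.5,)  # placeholder per-channel standard deviation

    def prepare_data(self):
        pass  # a real dataset would download or locate its files here

    def train_and_val_dataset(self):
        # Split the validation data off the training data of the current task.
        samples = self._samples()
        random.Random(0).shuffle(samples)
        num_val = round(len(samples) * self.validation_percentage)
        return samples[num_val:], samples[:num_val]

    def test_dataset(self):
        return self._samples()  # toy data: reuses the same generator


dataset = ToySplitDataset(num_tasks=2, validation_percentage=0.2)
dataset.prepare_data()
train, val = dataset.train_and_val_dataset()
print(dataset.cl_class_map(), len(train), len(val))
```

In a real implementation the datasets would typically be PyTorch dataset objects rather than Python lists; the lists here only keep the sketch free of dependencies.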
For permuted CL datasets:
- Specify the class property img_size, which is used to construct the permutation operation.
- Specify the class properties mean_original and std_original instead of implementing mean() and std().
Please refer to the API documentation for detailed information about implementation, and remember that you can always use the CL datasets already implemented in the package source code under clarena/cl_datasets/ as examples, such as PermutedMNIST!