clarena.cl_datasets
Continual Learning Datasets
This submodule provides the continual learning datasets that can be used in CLArena.
Please note that this is the API documentation. Please refer to the main documentation pages for more information about the CL datasets and how to configure and implement them:
- Configure CL Dataset
- Implement Your CL Dataset Class
- A Beginners' Guide to Continual Learning (CL Dataset)
The datasets are implemented as subclasses of `CLDataset`, which is the base class for all continual learning datasets in CLArena.
- `CLDataset`: The base class for continual learning datasets.
- `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
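To make the workflow concrete, here is a minimal usage sketch showing how one of these datasets is driven through a task loop. It assumes that `PermutedMNIST` accepts the same constructor arguments as its parent class `CLPermutedDataset`; in practice these objects are normally built from the experiment configuration (see the configuration guide linked above) rather than instantiated by hand.

```python
from clarena.cl_datasets import PermutedMNIST

# Assumed constructor arguments, mirroring CLPermutedDataset.__init__.
cl_dataset = PermutedMNIST(
    root="data",
    num_tasks=10,
    validation_percentage=0.1,
    batch_size=128,
    num_workers=4,
)

cl_dataset.set_cl_paradigm("TIL")  # 'TIL' or 'CIL'; defines the CL class maps
cl_dataset.prepare_data()          # download / prepare the original data files

for task_id in range(1, cl_dataset.num_tasks + 1):
    cl_dataset.setup_task_id(task_id)  # task IDs are 1-indexed
    cl_dataset.setup("fit")            # build train / validation datasets of this task
    train_loader = cl_dataset.train_dataloader()
    val_loader = cl_dataset.val_dataloader()
    # ... train and validate the model on this task ...

    cl_dataset.setup("test")                     # register this task's test dataset
    test_loaders = cl_dataset.test_dataloader()  # dict of DataLoaders over all seen tasks
```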
1r""" 2 3# Continual Learning Datasets 4 5This submodule provides the **continual learning datasets** that can be used in CLArena. 6 7Please note that this is an API documantation. Please refer to the main documentation pages for more information about the CL datasets and how to configure and implement them: 8 9- [**Configure CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/cl-dataset) 10- [**Implement Your CL Dataset Class**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset) 11- [**A Beginners' Guide to Continual Learning (CL Dataset)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset) 12 13The datasets are implemented as subclasses of `CLDataset` classes, which are the base class for all continual learning datasets in CLArena. 14 15- `CLDataset`: The base class for continual learning datasets. 16- `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`. 17 18""" 19 20from .base import CLClassMapping, CLDataset, CLPermutedDataset, CLSplitDataset, Permute 21from .permuted_mnist import PermutedMNIST 22from .split_cifar100 import SplitCIFAR100 23 24__all__ = [ 25 "CLDataset", 26 "CLPermutedDataset", 27 "CLSplitDataset", 28 "CLClassMapping", 29 "Permute", 30 "permuted_mnist", 31 "split_cifar100", 32]
class CLDataset(LightningDataModule):
    r"""The base class of continual learning datasets, inherited from `LightningDataModule`."""
The base class of continual learning datasets, inherited from `LightningDataModule`.
24 def __init__( 25 self, 26 root: str, 27 num_tasks: int, 28 validation_percentage: float, 29 batch_size: int = 1, 30 num_workers: int = 10, 31 custom_transforms: Callable | transforms.Compose | None = None, 32 custom_target_transforms: Callable | transforms.Compose | None = None, 33 ) -> None: 34 r"""Initialise the CL dataset object providing the root where data files live. 35 36 **Args:** 37 - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live. 38 - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. 39 - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data. 40 - **batch_size** (`int`): The batch size in train, val, test dataloader. 41 - **num_workers** (`int`): the number of workers for dataloaders. 42 - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included. 43 - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included. 44 """ 45 LightningDataModule.__init__(self) 46 47 self.root: str = root 48 r"""Store the root directory of the original data files. Used when constructing the dataset.""" 49 self.num_tasks: int = num_tasks 50 r"""Store the maximum number of tasks supported by the dataset.""" 51 self.validation_percentage: float = validation_percentage 52 r"""Store the percentage to randomly split some of the training data into validation data.""" 53 self.batch_size: int = batch_size 54 r"""Store the batch size. Used when constructing train, val, test dataloader.""" 55 self.num_workers: int = num_workers 56 r"""Store the number of workers. Used when constructing train, val, test dataloader.""" 57 self.custom_transforms: Callable | transforms.Compose | None = custom_transforms 58 r"""Store the custom transforms other than the basics. Used when constructing the dataset.""" 59 self.custom_target_transforms: Callable | transforms.Compose | None = ( 60 custom_target_transforms 61 ) 62 r"""Store the custom target transforms other than the CL class mapping. Used when constructing the dataset.""" 63 64 self.task_id: int 65 r"""Task ID counter indicating which task is being processed. Self updated during the task loop.""" 66 self.cl_paradigm: str 67 r"""Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from `set_cl_paradigm` and used to define the CL class map.""" 68 69 self.cl_class_map_t: dict[str | int, int] 70 r"""Store the CL class map for the current task `self.task_id`. """ 71 self.cl_class_mapping_t: Callable 72 r"""Store the CL class mapping transform for the current task `self.task_id`. """ 73 74 self.dataset_train: object 75 r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object.""" 76 self.dataset_val: object 77 r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.""" 78 self.dataset_test: dict[str, object] = {} 79 r"""The dictionary to store test dataset object. Keys are task IDs (string type) and values are the dataset objects. 
Can be PyTorch Dataset objects or any other dataset objects.""" 80 81 CLDataset.sanity_check(self)
Initialise the CL dataset object providing the root where data files live.
**Args:**

- **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
- **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
- **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
- **batch_size** (`int`): the batch size in train, val, test dataloader.
- **num_workers** (`int`): the number of workers for dataloaders.
- **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
- **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
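For illustration, the `custom_transforms` argument would typically carry a data-augmentation pipeline applied to the training data only; `ToTensor()` and normalisation are added automatically by the transform generators described below and should not be repeated here. `MyCLDataset` is a hypothetical subclass used only to show where the argument goes.

```python
from torchvision import transforms

# Hypothetical CLDataset subclass; only custom_transforms is of interest here.
cl_dataset = MyCLDataset(
    root="data",
    num_tasks=5,
    validation_percentage=0.1,
    custom_transforms=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    ),
)
```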
Store the percentage to randomly split some of the training data into validation data.
Store the custom transforms other than the basics. Used when constructing the dataset.
Store the custom target transforms other than the CL class mapping. Used when constructing the dataset.
Task ID counter indicating which task is being processed. Self updated during the task loop.
Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from `set_cl_paradigm` and used to define the CL class map.
Store the CL class mapping transform for the current task `self.task_id`.
The training dataset object. Can be a PyTorch Dataset object or any other dataset object.
The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.
The dictionary to store test dataset object. Keys are task IDs (string type) and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.
    def sanity_check(self) -> None:
        r"""Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: when the `validation_percentage` is not in the range of 0-1.
        """
        if not 0.0 < self.validation_percentage < 1.0:
            raise ValueError("The validation_percentage should be 0-1.")
Check the sanity of the arguments.
**Raises:**

- **ValueError**: when the `validation_percentage` is not in the range of 0-1.
    @abstractmethod
    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): The task ID to query CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """
The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.

**Args:**

- **task_id** (`int`): The task ID to query CL class map.

**Returns:**

- **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
    - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
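As an illustration, a subclass in which every task contains the same 10 original classes might implement this roughly as follows (a sketch, not the implementation of any built-in dataset):

```python
def cl_class_map(self, task_id: int) -> dict[str | int, int]:
    num_classes_t = 10  # assumed number of classes per task in this sketch
    if self.cl_paradigm == "TIL":
        # TIL: mapped labels restart from 0 within each task
        return {original: original for original in range(num_classes_t)}
    if self.cl_paradigm == "CIL":
        # CIL: mapped labels continue after those of all previous tasks
        offset = (task_id - 1) * num_classes_t
        return {original: original + offset for original in range(num_classes_t)}
```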
    @abstractmethod
    def prepare_data(self) -> None:
        r"""Use this to download and prepare data. It must be implemented by subclasses, regulated by `LightningDataModule`."""
Use this to download and prepare data. It must be implemented by subclasses, regulated by `LightningDataModule`.
    def setup(self, stage: str) -> None:
        r"""Set up the dataset for different stages.

        **Args:**
        - **stage** (`str`): the stage of the experiment. Should be one of the following:
            - 'fit' or 'validate': the training and validation dataset of the current task `self.task_id` should be assigned to `self.dataset_train` and `self.dataset_val`.
            - 'test': the test dataset of the current task should be added to `self.dataset_test`, which keeps the test datasets of all seen tasks (from task 1 to `self.task_id`).
        """
        if stage in ("fit", "validate"):
            pylogger.debug(
                "Construct train and validation dataset for task %d...", self.task_id
            )
            self.dataset_train, self.dataset_val = self.train_and_val_dataset()
            self.dataset_train.target_transform = (
                self.target_transforms()
            )  # apply target transform after potential class split
            self.dataset_val.target_transform = (
                self.target_transforms()
            )  # apply target transform after potential class split
            pylogger.debug(
                "Train and validation dataset for task %d are ready.", self.task_id
            )

        if stage == "test":
            pylogger.debug("Construct test dataset for task %d...", self.task_id)
            self.dataset_test[f"{self.task_id}"] = self.test_dataset()
            self.dataset_test[f"{self.task_id}"].target_transform = (
                self.target_transforms()
            )  # apply target transform after potential class split
            pylogger.debug("Test dataset for task %d is ready.", self.task_id)
Set up the dataset for different stages.
**Args:**

- **stage** (`str`): the stage of the experiment. Should be one of the following:
    - 'fit' or 'validate': the training and validation dataset of the current task `self.task_id` should be assigned to `self.dataset_train` and `self.dataset_val`.
    - 'test': the test dataset of the current task should be added to `self.dataset_test`, which keeps the test datasets of all seen tasks (from task 1 to `self.task_id`).
    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """
        self.task_id = task_id

        self.cl_class_map_t = self.cl_class_map(task_id)
        self.cl_class_mapping_t = CLClassMapping(self.cl_class_map_t)
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

**Args:**

- **task_id** (`int`): the target task ID.
    def set_cl_paradigm(self, cl_paradigm: str) -> None:
        r"""Set the continual learning paradigm to `self.cl_paradigm`. It is used to define the CL class map.

        **Args:**
        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
        """
        self.cl_paradigm = cl_paradigm
Set the continual learning paradigm to `self.cl_paradigm`. It is used to define the CL class map.

**Args:**

- **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
    @abstractmethod
    def mean(self, task_id: int) -> tuple[float]:
        r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset.

        **Returns:**
        - **mean** (`tuple[float]`): the mean values for normalisation.
        """
The mean values for normalisation of task `task_id`. Used when constructing the dataset.

**Returns:**

- **mean** (`tuple[float]`): the mean values for normalisation.
    @abstractmethod
    def std(self, task_id: int) -> tuple[float]:
        r"""The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset.

        **Returns:**
        - **std** (`tuple[float]`): the standard deviation values for normalisation.
        """
The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset.

**Returns:**

- **std** (`tuple[float]`): the standard deviation values for normalisation.
177 def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose: 178 r"""Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset. 179 180 **Args:** 181 - **to_tensor** (`bool`): whether to include `ToTensor()` transform. 182 183 **Returns:** 184 - **train_and_val_transforms** (`transforms.Compose`): the composed training transforms. 185 """ 186 187 return transforms.Compose( 188 list( 189 filter( 190 None, 191 [ 192 transforms.ToTensor() if to_tensor else None, 193 self.custom_transforms, 194 transforms.Normalize( 195 self.mean(self.task_id), self.std(self.task_id) 196 ), 197 ], 198 ) 199 ) 200 ) # the order of transforms matters
Transforms generator for train and validation dataset, incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset.

**Args:**

- **to_tensor** (`bool`): whether to include the `ToTensor()` transform.

**Returns:**

- **train_and_val_transforms** (`transforms.Compose`): the composed training transforms.
202 def test_transforms(self, to_tensor: bool) -> transforms.Compose: 203 r"""Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset. 204 205 **Args:** 206 - **to_tensor** (`bool`): whether to include `ToTensor()` transform. 207 208 **Returns:** 209 - **test_transforms** (`transforms.Compose`): the composed training transforms. 210 """ 211 212 return transforms.Compose( 213 list( 214 filter( 215 None, 216 [ 217 transforms.ToTensor() if to_tensor else None, 218 transforms.Normalize( 219 self.mean(self.task_id), self.std(self.task_id) 220 ), 221 ], 222 ) 223 ) 224 ) # the order of transforms matters
Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.

**Args:**

- **to_tensor** (`bool`): whether to include the `ToTensor()` transform.

**Returns:**

- **test_transforms** (`transforms.Compose`): the composed test transforms.
226 def target_transforms(self) -> transforms.Compose: 227 r"""The target transform for the dataset. It is a handy tool to use in subclasses when constructing the dataset. 228 229 **Args:** 230 - **target** (`Tensor`): the target tensor. 231 232 **Returns:** 233 - **target_transforms** (`transforms.Compose`): the transformed target tensor. 234 """ 235 236 return transforms.Compose( 237 list( 238 filter( 239 None, 240 [ 241 self.custom_target_transforms, 242 self.cl_class_mapping_t, 243 ], 244 ) 245 ) 246 ) # the order of transforms matters
The target transforms for the dataset, composing the custom target transforms with the CL class mapping. It is a handy tool to use in subclasses when constructing the dataset.

**Returns:**

- **target_transforms** (`transforms.Compose`): the composed target transforms.
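Conceptually, the CL class mapping composed here is just a callable that looks each label up in the current task's class map. A rough, hypothetical stand-in for what `CLClassMapping` (from `clarena.cl_datasets.base`) does:

```python
class SimpleClassMapping:
    # Hypothetical illustration of a CL class mapping transform.
    def __init__(self, cl_class_map: dict) -> None:
        self.cl_class_map = cl_class_map

    def __call__(self, target):
        # map the original class label to its continual learning label
        return self.cl_class_map[target]
```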
    @abstractmethod
    def train_and_val_dataset(self) -> Any:
        r"""Get the training and validation dataset of task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - **train_and_val_dataset** (`Any`): the train and validation dataset of task `self.task_id`.
        """
Get the training and validation dataset of task `self.task_id`. It must be implemented by subclasses.

**Returns:**

- **train_and_val_dataset** (`Any`): the train and validation dataset of task `self.task_id`.
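A subclass built on a torchvision-style dataset might implement this by loading the training split with `train_and_val_transforms()` and splitting off the validation portion according to `validation_percentage`. This is only a sketch under those assumptions; `MNIST` stands in for whatever original dataset the subclass wraps, and the other abstract members are omitted.

```python
from torch.utils.data import random_split
from torchvision.datasets import MNIST

from clarena.cl_datasets import CLPermutedDataset


class MyPermutedMNIST(CLPermutedDataset):
    # ... other abstract members (prepare_data, test_dataset, ...) omitted ...

    def train_and_val_dataset(self):
        dataset_train_and_val = MNIST(
            root=self.root,
            train=True,
            transform=self.train_and_val_transforms(to_tensor=True),
            download=False,  # prepare_data() is assumed to have downloaded the files
        )
        # split off the validation portion
        return random_split(
            dataset_train_and_val,
            lengths=[1 - self.validation_percentage, self.validation_percentage],
        )
```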
    @abstractmethod
    def test_dataset(self) -> Any:
        r"""Get the test dataset of task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - **test_dataset** (`Any`): the test dataset of task `self.task_id`.
        """
Get the test dataset of task `self.task_id`. It must be implemented by subclasses.

**Returns:**

- **test_dataset** (`Any`): the test dataset of task `self.task_id`.
264 def train_dataloader(self) -> DataLoader: 265 r"""DataLoader generator for stage train of task `self.task_id`. It is automatically called before training. 266 267 **Returns:** 268 - **train_dataloader** (`Dataloader`): the train DataLoader of task `self.task_id`. 269 """ 270 271 pylogger.debug("Construct train dataloader for task %d...", self.task_id) 272 273 return DataLoader( 274 dataset=self.dataset_train, 275 batch_size=self.batch_size, 276 shuffle=True, # shuffle train batch to prevent overfitting 277 num_workers=self.num_workers, 278 persistent_workers=True, 279 )
DataLoader generator for stage train of task `self.task_id`. It is automatically called before training.

**Returns:**

- **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
281 def val_dataloader(self) -> DataLoader: 282 r"""DataLoader generator for stage validate. It is automatically called before validation. 283 284 **Returns:** 285 - **val_dataloader** (`Dataloader`): the validation DataLoader of task `self.task_id`. 286 """ 287 288 pylogger.debug("Construct validation dataloader for task %d...", self.task_id) 289 290 return DataLoader( 291 dataset=self.dataset_val, 292 batch_size=self.batch_size, 293 shuffle=False, # don't have to shuffle val or test batch 294 num_workers=self.num_workers, 295 persistent_workers=True, 296 )
DataLoader generator for stage validate. It is automatically called before validation.

**Returns:**

- **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
298 def test_dataloader(self) -> dict[int, DataLoader]: 299 r"""DataLoader generator for stage test. It is automatically called before testing. 300 301 **Returns:** 302 - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs (integer type) and values are the DataLoaders. 303 """ 304 305 pylogger.debug("Construct test dataloader for task %d...", self.task_id) 306 307 return { 308 task_id: DataLoader( 309 dataset=dataset_test, 310 batch_size=self.batch_size, 311 shuffle=False, # don't have to shuffle val or test batch 312 num_workers=self.num_workers, 313 persistent_workers=True, # speed up the dataloader worker initialization 314 ) 315 for task_id, dataset_test in self.dataset_test.items() 316 }
DataLoader generator for stage test. It is automatically called before testing.

**Returns:**

- **test_dataloader** (`dict[int, DataLoader]`): the dict of test DataLoaders for `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
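Because the test DataLoaders come back as a dictionary keyed by task ID, evaluation typically loops over it to obtain one metric per seen task. A minimal sketch, where `evaluate` is a hypothetical helper returning the accuracy of `model` on one DataLoader:

```python
test_loaders = cl_dataset.test_dataloader()

accuracy_per_task = {}
for task_id, test_loader in test_loaders.items():
    accuracy_per_task[task_id] = evaluate(model, test_loader)  # hypothetical helper

# average accuracy over all tasks seen so far
average_accuracy = sum(accuracy_per_task.values()) / len(accuracy_per_task)
```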
class CLPermutedDataset(CLDataset):
    r"""The base class of continual learning datasets which are constructed as permutations from an original dataset, inherited from `CLDataset`."""

    num_classes: int
    r"""The number of classes in the original dataset before permutation. It must be provided in subclasses."""

    img_size: torch.Size
    r"""The size of images in the original dataset before permutation. Used when constructing permutation operations. It must be provided in subclasses."""

    mean_original: tuple[float]
    r"""The mean values for normalisation. It must be provided in subclasses."""

    std_original: tuple[float]
    r"""The standard deviation values for normalisation. It must be provided in subclasses."""
The base class of continual learning datasets which are constructed as permutations from an original dataset, inherited from `CLDataset`.
334 def __init__( 335 self, 336 root: str, 337 num_tasks: int, 338 validation_percentage: float, 339 batch_size: int = 1, 340 num_workers: int = 10, 341 custom_transforms: Callable | transforms.Compose | None = None, 342 custom_target_transforms: Callable | transforms.Compose | None = None, 343 permutation_mode: str = "first_channel_only", 344 permutation_seeds: list[int] | None = None, 345 ) -> None: 346 r"""Initialise the CL dataset object providing the root where data files live. 347 348 **Args:** 349 - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live. 350 - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. 351 - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data. 352 - **batch_size** (`int`): The batch size in train, val, test dataloader. 353 - **num_workers** (`int`): the number of workers for dataloaders. 354 - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included. 355 - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included. 356 - **permutation_mode** (`str`): the mode of permutation, should be one of the following: 357 1. 'all': permute all pixels. 358 2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order. 359 3. 'first_channel_only': permute only the first channel. 360 - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`. 361 """ 362 CLDataset.__init__( 363 self, 364 root=root, 365 num_tasks=num_tasks, 366 validation_percentage=validation_percentage, 367 batch_size=batch_size, 368 num_workers=num_workers, 369 custom_transforms=custom_transforms, 370 custom_target_transforms=custom_target_transforms, 371 ) 372 373 self.permutation_mode: str = permutation_mode 374 r"""Store the mode of permutation. Used when permutation operations used to construct tasks. """ 375 376 self.permutation_seeds: list[int] = ( 377 permutation_seeds if permutation_seeds else list(range(num_tasks)) 378 ) 379 r"""Store the permutation seeds for all tasks. Use when permutation operations used to construct tasks. """ 380 381 self.permutation_seed_t: int 382 r"""Store the permutation seed for the current task `self.task_id`.""" 383 self.permute_t: Permute 384 r"""Store the permutation transform for the current task `self.task_id`. """ 385 386 CLPermutedDataset.sanity_check(self)
Initialise the CL dataset object providing the root where data files live.
**Args:**

- **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
- **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
- **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
- **batch_size** (`int`): the batch size in train, val, test dataloader.
- **num_workers** (`int`): the number of workers for dataloaders.
- **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
- **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
- **permutation_mode** (`str`): the mode of permutation, should be one of the following:
    1. 'all': permute all pixels.
    2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
    3. 'first_channel_only': permute only the first channel.
- **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is `None`, which creates a list of seeds from 0 to `num_tasks - 1`.
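For example, fixing the permutation seeds makes the constructed tasks reproducible across runs. The sketch below assumes a `PermutedMNIST`-style subclass that simply forwards these arguments to `CLPermutedDataset`:

```python
cl_dataset = PermutedMNIST(
    root="data",
    num_tasks=5,
    validation_percentage=0.1,
    permutation_mode="first_channel_only",  # 'all', 'by_channel' or 'first_channel_only'
    permutation_seeds=[11, 22, 33, 44, 55],  # one seed per task; length must equal num_tasks
)
```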
The number of classes in the original dataset before permutation. It must be provided in subclasses.
The size of images in the original dataset before permutation. Used when constructing permutation operations. It must be provided in subclasses.
The standard deviation values for normalisation. It must be provided in subclasses.
Store the mode of permutation. Used by the permutation operations that construct tasks.
Store the permutation seeds for all tasks. Used by the permutation operations that construct tasks.
388 def sanity_check(self) -> None: 389 r"""Check the sanity of the arguments. 390 391 **Raises:** 392 - **ValueError**: when the `permutation_seeds` is not equal to `num_tasks`, or the `permutation_mode` is not one of the valid options. 393 """ 394 if self.permutation_seeds and self.num_tasks != len(self.permutation_seeds): 395 raise ValueError( 396 "The number of permutation seeds is not equal to number of tasks!" 397 ) 398 if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]: 399 raise ValueError( 400 "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'." 401 )
Check the sanity of the arguments.
**Raises:**

- **ValueError**: when the number of `permutation_seeds` is not equal to `num_tasks`, or the `permutation_mode` is not one of the valid options.
403 def cl_class_map(self, task_id: int) -> dict[str | int, int]: 404 r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. 405 406 **Args:** 407 - **task_id** (`int`): The task ID to query CL class map. 408 409 **Returns:** 410 - **cl_class_map**(`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning. 411 - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes. 412 - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task. 413 """ 414 if self.cl_paradigm == "TIL": 415 return {i: i for i in range(self.num_classes)} 416 if self.cl_paradigm == "CIL": 417 return { 418 i: i + (task_id - 1) * self.num_classes for i in range(self.num_classes) 419 }
The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

**Args:**

- **task_id** (`int`): The task ID to query CL class map.

**Returns:**

- **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
    - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
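For a 10-class original dataset (e.g. MNIST), the two paradigms therefore produce the following maps; the worked example below assumes `cl_dataset` is such a permuted dataset and queries task 3:

```python
cl_dataset.set_cl_paradigm("TIL")
cl_dataset.cl_class_map(task_id=3)
# -> {0: 0, 1: 1, ..., 9: 9}      labels restart from 0 for every task

cl_dataset.set_cl_paradigm("CIL")
cl_dataset.cl_class_map(task_id=3)
# -> {0: 20, 1: 21, ..., 9: 29}   labels are shifted past the 20 classes of tasks 1 and 2
```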
421 def setup_task_id(self, task_id: int) -> None: 422 r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called. 423 424 **Args:** 425 - **task_id** (`int`): the target task ID. 426 """ 427 super().setup_task_id(task_id) 428 429 self.permutation_seed_t = self.permutation_seeds[task_id - 1] 430 self.permute_t = Permute( 431 img_size=self.img_size, 432 mode=self.permutation_mode, 433 seed=self.permutation_seed_t, 434 )
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

**Args:**

- **task_id** (`int`): the target task ID.
436 def mean(self, task_id: int) -> tuple[float]: 437 r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset. 438 439 **Returns:** 440 - **mean** (`tuple[float]`): the mean values for normalisation. 441 """ 442 return self.mean_original
The mean values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL datasets, the mean values are the same as in the original dataset.

**Returns:**

- **mean** (`tuple[float]`): the mean values for normalisation.
    def std(self, task_id: int) -> tuple[float]:
        r"""The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL datasets, the standard deviation values are the same as in the original dataset.

        **Returns:**
        - **std** (`tuple[float]`): the standard deviation values for normalisation.
        """
        return self.std_original
The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL datasets, the standard deviation values are the same as in the original dataset.

**Returns:**

- **std** (`tuple[float]`): the standard deviation values for normalisation.
452 def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose: 453 r"""Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. In permuted CL datasets, permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset. 454 455 **Args:** 456 - **to_tensor** (`bool`): whether to include `ToTensor()` transform. 457 458 **Returns:** 459 - **train_and_val_transforms** (`transforms.Compose`): the composed training transforms. 460 """ 461 462 return transforms.Compose( 463 list( 464 filter( 465 None, 466 [ 467 transforms.ToTensor() if to_tensor else None, 468 self.permute_t, 469 self.custom_transforms, 470 transforms.Normalize( 471 self.mean(self.task_id), self.std(self.task_id) 472 ), 473 ], 474 ) 475 ) 476 ) # the order of transforms matters
Transforms generator for train and validation dataset, incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. In permuted CL datasets, the permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.

**Args:**

- **to_tensor** (`bool`): whether to include the `ToTensor()` transform.

**Returns:**

- **train_and_val_transforms** (`transforms.Compose`): the composed training transforms.
478 def test_transforms(self, to_tensor: bool) -> transforms.Compose: 479 r"""Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset. 480 481 **Args:** 482 - **to_tensor** (`bool`): whether to include `ToTensor()` transform. 483 484 **Returns:** 485 - **test_transforms** (`transforms.Compose`): the composed training transforms. 486 """ 487 488 return transforms.Compose( 489 list( 490 filter( 491 None, 492 [ 493 transforms.ToTensor() if to_tensor else None, 494 self.permute_t, 495 transforms.Normalize( 496 self.mean(self.task_id), self.std(self.task_id) 497 ), 498 ], 499 ) 500 ) 501 ) # the order of transforms matters
Transforms generator for test dataset. In permuted CL datasets, the permute transform is included along with basic transforms like `normalisation` and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset.

**Args:**

- **to_tensor** (`bool`): whether to include the `ToTensor()` transform.

**Returns:**

- **test_transforms** (`transforms.Compose`): the composed test transforms.
class CLSplitDataset(CLDataset):
    r"""The base class of continual learning datasets which are constructed as class splits of an original dataset, inherited from `CLDataset`."""

    num_classes: int
    r"""The number of classes in the original dataset. It must be provided in subclasses."""

    mean_original: tuple[float]
    r"""The mean values for normalisation. It must be provided in subclasses."""

    std_original: tuple[float]
    r"""The standard deviation values for normalisation. It must be provided in subclasses."""

    def __init__(
        self,
        root: str,
        num_tasks: int,
        class_split: list[list[int]],
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 10,
        custom_transforms: Callable | transforms.Compose | None = None,
        custom_target_transforms: Callable | transforms.Compose | None = None,
    ) -> None:
        r"""Initialise the CL dataset object providing the root where data files live.

        **Args:**
        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
        - **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
        """
        CLDataset.__init__(
            self,
            root=root,
            num_tasks=num_tasks,
            validation_percentage=validation_percentage,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            custom_target_transforms=custom_target_transforms,
        )

        self.class_split = class_split
        r"""Store the class split for each task. Used when constructing the split dataset."""

        CLSplitDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: when the length of `class_split` is not equal to `num_tasks`.
        - **ValueError**: when any of the lists in `class_split` has fewer than 2 elements. A classification task must have at least 2 classes.
        """
        if len(self.class_split) != self.num_tasks:
            raise ValueError(
                "The length of class split is not equal to number of tasks!"
            )
        if any(len(split) < 2 for split in self.class_split):
            raise ValueError("Each class split must contain at least 2 elements!")

    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): The task ID to query CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """
        num_classes_t = len(self.class_split[task_id - 1])
        if self.cl_paradigm == "TIL":
            return {self.class_split[task_id - 1][i]: i for i in range(num_classes_t)}
        if self.cl_paradigm == "CIL":
            num_classes_previous = sum(
                len(self.class_split[i]) for i in range(task_id - 1)
            )
            return {
                self.class_split[task_id - 1][i]: i + num_classes_previous
                for i in range(num_classes_t)
            }

    def mean(self, task_id: int) -> tuple[float]:
        r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset. In split CL datasets, the mean values are the same as in the original dataset.

        **Returns:**
        - **mean** (`tuple[float]`): the mean values for normalisation.
        """
        return self.mean_original

    def std(self, task_id: int) -> tuple[float]:
        r"""The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In split CL datasets, the standard deviation values are the same as in the original dataset.

        **Returns:**
        - **std** (`tuple[float]`): the standard deviation values for normalisation.
        """
        return self.std_original

    def get_class_subset(self, dataset: Dataset) -> Dataset:
        r"""A utility method to retrieve the subset of a PyTorch `Dataset` containing only the classes of the current task `self.task_id`. It can be useful when constructing the split CL dataset.

        **Args:**
        - **dataset** (`Dataset`): the original dataset to retrieve the subset from.

        **Returns:**
        - **subset** (`Dataset`): the subset of the original dataset restricted to the current task's classes.
        """
        classes = self.class_split[self.task_id - 1]

        # get the indices of the samples that belong to the current task's classes
        idx = [i for i, (_, target) in enumerate(dataset) if target in classes]

        # subset the dataset by the indices; this is an in-place operation
        dataset.data = dataset.data[idx]
        dataset.targets = [dataset.targets[i] for i in idx]

        return dataset
The base class of split continual learning datasets, which are constructed by splitting the classes of an original dataset across tasks, inherited from `CLDataset`.
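To make the TIL/CIL class mapping concrete, here is a small, self-contained sketch of what `cl_class_map` produces for a hypothetical 3-task split; the `class_split` values and the helper name `example_cl_class_map` are made up for illustration only.

```python
# Hypothetical 3-task split over 6 original classes (illustrative values only).
class_split = [[0, 1], [2, 3], [4, 5]]


def example_cl_class_map(class_split, task_id, cl_paradigm):
    """Standalone sketch mirroring CLSplitDataset.cl_class_map, for illustration."""
    classes_t = class_split[task_id - 1]
    if cl_paradigm == "TIL":
        # each task gets its own label space starting from 0
        return {c: i for i, c in enumerate(classes_t)}
    if cl_paradigm == "CIL":
        # labels continue on from the classes of all previous tasks
        offset = sum(len(split) for split in class_split[: task_id - 1])
        return {c: i + offset for i, c in enumerate(classes_t)}


assert example_cl_class_map(class_split, task_id=2, cl_paradigm="TIL") == {2: 0, 3: 1}
assert example_cl_class_map(class_split, task_id=2, cl_paradigm="CIL") == {2: 2, 3: 3}
```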
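Note that `get_class_subset` assumes a torchvision-style dataset exposing `.data` and `.targets`. The following standalone sketch (the function name and class list are hypothetical, not part of the API) shows the same filtering idea applied to CIFAR-100:

```python
from torchvision import datasets, transforms


def filter_classes(dataset, classes):
    """Keep only the samples whose target is in `classes` (same idea as get_class_subset)."""
    idx = [i for i, (_, target) in enumerate(dataset) if target in classes]
    dataset.data = dataset.data[idx]  # relies on `.data` supporting fancy indexing
    dataset.targets = [dataset.targets[i] for i in idx]
    return dataset


cifar = datasets.CIFAR100(
    root="data", train=True, download=True, transform=transforms.ToTensor()
)
task_1_classes = [0, 1, 2, 3, 4]  # a hypothetical first split of a Split CIFAR-100 setup
cifar_task_1 = filter_classes(cifar, task_1_classes)
```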
class CLClassMapping:
    r"""CL Class mapping to dataset labels. Used as a PyTorch target Transform."""

    def __init__(self, cl_class_map: dict[str | int, int]) -> None:
        r"""Initialise the CL class mapping transform object from the CL class map of a task.

        **Args:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map for a task.
        """
        self.cl_class_map = cl_class_map

    def __call__(self, target: torch.Tensor) -> torch.Tensor:
        r"""The CL class mapping transform to dataset labels. It is defined as a callable object like a PyTorch Transform.

        **Args:**
        - **target** (`Tensor`): the target tensor.

        **Returns:**
        - **transformed_target** (`Tensor`): the transformed target tensor.
        """
        return self.cl_class_map[target]
CL Class mapping to dataset labels. Used as a PyTorch target Transform.
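As a usage sketch (the map values and the MNIST wiring below are illustrative; in practice the CL dataset classes apply this internally), a `CLClassMapping` built from a task's CL class map can be passed as a torchvision `target_transform`:

```python
from torchvision import datasets, transforms
from clarena.cl_datasets import CLClassMapping

# Hypothetical TIL-style class map for a task whose original classes are 3 and 7.
cl_class_map_t = {3: 0, 7: 1}
cl_class_mapping_t = CLClassMapping(cl_class_map_t)

# Illustrative: pass it as the target transform of the underlying dataset.
# (In practice the dataset is first restricted to the task's classes,
# e.g. via get_class_subset, so every target appears in the map.)
mnist = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
    target_transform=cl_class_mapping_t,
)
```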
class Permute:
    r"""Permutation operation applied to images. Used to construct permuted CL datasets.

    Used as a PyTorch Dataset Transform.
    """

    def __init__(
        self,
        img_size: torch.Size,
        mode: str = "first_channel_only",
        seed: int | None = None,
    ) -> None:
        r"""Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.

        **Args:**
        - **img_size** (`torch.Size`): the size of the image to be permuted.
        - **mode** (`str`): the mode of permutation, should be one of the following:
            - 'all': permute all pixels.
            - 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
            - 'first_channel_only': permute only the first channel.
        - **seed** (`int` or `None`): seed for the permutation operation. If `None`, the permutation will use a default seed from the PyTorch generator.
        """
        self.mode = mode
        r"""Store the mode of permutation."""

        # get generator for permutation
        torch_generator = torch.Generator()
        if seed is not None:
            torch_generator.manual_seed(seed)

        # calculate the number of pixels to permute from the image size
        if self.mode == "all":
            num_pixels = img_size[0] * img_size[1] * img_size[2]
        elif self.mode in ("by_channel", "first_channel_only"):
            num_pixels = img_size[1] * img_size[2]

        self.permute: torch.Tensor = torch.randperm(
            num_pixels, generator=torch_generator
        )
        r"""The permutation order, a `Tensor` permuted from [0, 1, ..., `num_pixels` - 1] with the given seed. It is the core element of the permutation operation."""

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        r"""The permutation operation applied to an image, defined as a callable object like a PyTorch Transform.

        **Args:**
        - **img** (`Tensor`): the image to be permuted. Must match the `img_size` given in the initialisation.

        **Returns:**
        - **img_permuted** (`Tensor`): the permuted image.
        """

        if self.mode == "all":
            # flatten the whole image to 1D so the 1D permutation order can be applied
            img_flat = img.view(-1)
            img_flat_permuted = img_flat[self.permute]  # conduct the permutation operation
            img_permuted = img_flat_permuted.view(img.size())  # return to the original image shape
            return img_permuted

        if self.mode == "by_channel":
            permuted_channels = []
            for i in range(img.size(0)):
                # act on every channel: flatten to 1D, permute, return to the original channel shape
                channel_flat = img[i].view(-1)
                channel_flat_permuted = channel_flat[self.permute]
                channel_permuted = channel_flat_permuted.view(img[0].size())
                permuted_channels.append(channel_permuted)
            img_permuted = torch.stack(permuted_channels)  # stack the permuted channels to restore the image
            return img_permuted

        if self.mode == "first_channel_only":
            # flatten the first channel to 1D, permute, return to the original channel shape
            first_channel_flat = img[0].view(-1)
            first_channel_flat_permuted = first_channel_flat[self.permute]
            first_channel_permuted = first_channel_flat_permuted.view(img[0].size())

            img_permuted = img.clone()
            img_permuted[0] = first_channel_permuted

            return img_permuted
Permutation operation applied to images. Used to construct permuted CL datasets.
Used as a PyTorch Dataset Transform.
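For example (a minimal sketch; the image size, mode and seed are arbitrary choices, and the `Compose` wiring is normally handled inside the permuted CL dataset classes), a `Permute` transform acts on image tensors and can be placed after `ToTensor()`:

```python
import torch
from torchvision import transforms
from clarena.cl_datasets import Permute

# One fixed permutation per task, reproducible through the seed.
permute_t = Permute(img_size=torch.Size([1, 28, 28]), mode="first_channel_only", seed=42)

# The permutation acts on tensors, so it comes after ToTensor() in a composed transform.
transform_t = transforms.Compose([transforms.ToTensor(), permute_t])

img = torch.rand(1, 28, 28)    # dummy image matching img_size
img_permuted = permute_t(img)  # same shape, pixels of the first channel shuffled
assert img_permuted.shape == img.shape
```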