clarena.cl_datasets
Continual Learning Datasets
This submodule provides the continual learning datasets that can be used in CLArena.
Please note that this is the API documentation. Please refer to the main documentation pages for more information about the CL datasets and how to use and customize them:
- Configure your CL dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/cl-dataset
- Implement your CL dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset
- A beginners' guide to continual learning (CL dataset): https://pengxiang-wang.com/posts/continual-learning-beginners-guide#CL-dataset
The datasets are implemented as subclasses of `CLDataset`, which is the base class for all continual learning datasets in CLArena.
- `CLDataset`: The base class for continual learning datasets.
- `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
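A minimal sketch of how a CL dataset is typically driven through the task loop is shown below. Argument values are illustrative only, and `PermutedMNIST` is assumed to accept the base-class constructor arguments documented further down this page.

```python
from clarena.cl_datasets import PermutedMNIST

# Sketch only: values are illustrative, and PermutedMNIST is assumed to
# forward these arguments to the CLDataset base class.
cl_dataset = PermutedMNIST(
    root="data/",
    num_tasks=10,
    validation_percentage=0.1,
    batch_size=64,
    num_workers=4,
)
cl_dataset.set_cl_paradigm("TIL")  # or "CIL"

for task_id in range(1, cl_dataset.num_tasks + 1):
    cl_dataset.setup_task_id(task_id)  # must be called before setup()
    cl_dataset.prepare_data()
    cl_dataset.setup("fit")            # builds train/val datasets for this task
    train_loader = cl_dataset.train_dataloader()
    val_loader = cl_dataset.val_dataloader()
    cl_dataset.setup("test")           # adds this task's test dataset
    test_loaders = cl_dataset.test_dataloader()  # dict: task_id -> DataLoader
```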
```python
"""

# Continual Learning Datasets

This submodule provides the **continual learning datasets** that can be used in CLArena.

Please note that this is the API documentation. Please refer to the main documentation pages for more information about the CL datasets and how to use and customize them:

- **Configure your CL dataset:** [https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/cl-dataset](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/cl-dataset)
- **Implement your CL dataset:** [https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset)
- **A beginners' guide to continual learning (CL dataset):** [https://pengxiang-wang.com/posts/continual-learning-beginners-guide#CL-dataset](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#CL-dataset)

The datasets are implemented as subclasses of `CLDataset`, which is the base class for all continual learning datasets in CLArena.

- `CLDataset`: The base class for continual learning datasets.
- `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.

"""

from .base import CLClassMapping, CLDataset, CLPermutedDataset, CLSplitDataset, Permute
from .permuted_mnist import PermutedMNIST
from .split_cifar100 import SplitCIFAR100

__all__ = [
    "CLDataset",
    "CLPermutedDataset",
    "CLSplitDataset",
    "CLClassMapping",
    "Permute",
    "permuted_mnist",
    "split_cifar100",
]
```
```python
class CLDataset(LightningDataModule):
    """The base class of continual learning datasets, inherited from `LightningDataModule`."""
```
The base class of continual learning datasets, inherited from `LightningDataModule`.
```python
    def __init__(
        self,
        root: str,
        num_tasks: int,
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 10,
        custom_transforms: Callable | transforms.Compose | None = None,
        custom_target_transforms: Callable | transforms.Compose | None = None,
    ):
        """Initialise the CL dataset object providing the root where data files live.

        **Args:**
        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
        """
        super().__init__()

        self.root: str = root
        """Store the root directory of the original data files. Used when constructing the dataset."""
        self.num_tasks: int = num_tasks
        """Store the maximum number of tasks supported by the dataset."""
        self.validation_percentage: float = validation_percentage
        """Store the percentage to randomly split some of the training data into validation data."""
        self.batch_size: int = batch_size
        """Store the batch size. Used when constructing train, val, test dataloader."""
        self.num_workers: int = num_workers
        """Store the number of workers. Used when constructing train, val, test dataloader."""
        self.custom_transforms: Callable | transforms.Compose | None = custom_transforms
        """Store the custom transforms other than the basics. Used when constructing the dataset."""
        self.custom_target_transforms: Callable | transforms.Compose | None = (
            custom_target_transforms
        )
        """Store the custom target transforms other than the CL class mapping. Used when constructing the dataset."""

        self.task_id: int
        """Task ID counter indicating which task is being processed. Self updated during the task loop."""
        self.cl_paradigm: str
        """Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from `set_cl_paradigm` and used to define the CL class map."""

        self.cl_class_map_t: dict[str | int, int]
        """Store the CL class map for the current task `self.task_id`."""
        self.cl_class_mapping_t: Callable
        """Store the CL class mapping transform for the current task `self.task_id`."""

        self.dataset_train: object
        """The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
        self.dataset_val: object
        """The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
        self.dataset_test: dict[int, object] = {}
        """The dictionary to store test dataset objects. Key is task_id, value is the dataset object. Can be a PyTorch Dataset object or any other dataset object."""

        self.sanity_check()
```
Initialise the CL dataset object providing the root where data files live.
Args:
- root (`str`): the root directory where the original data files for constructing the CL dataset physically live.
- num_tasks (`int`): the maximum number of tasks supported by the CL dataset.
- validation_percentage (`float`): the percentage to randomly split some of the training data into validation data.
- batch_size (`int`): the batch size in train, val, test dataloader.
- num_workers (`int`): the number of workers for dataloaders.
- custom_transforms (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
- custom_target_transforms (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
- `validation_percentage`: Store the percentage to randomly split some of the training data into validation data.
- `custom_transforms`: Store the custom transforms other than the basics. Used when constructing the dataset.
- `custom_target_transforms`: Store the custom target transforms other than the CL class mapping. Used when constructing the dataset.
- `task_id`: Task ID counter indicating which task is being processed. Self updated during the task loop.
- `cl_paradigm`: Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from `set_cl_paradigm` and used to define the CL class map.
- `cl_class_mapping_t`: Store the CL class mapping transform for the current task `self.task_id`.
- `dataset_train`: The training dataset object. Can be a PyTorch Dataset object or any other dataset object.
- `dataset_val`: The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.
- `dataset_test`: The dictionary to store test dataset objects. Key is task_id, value is the dataset object. Can be a PyTorch Dataset object or any other dataset object.
```python
    def sanity_check(self) -> None:
        """Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: when the `validation_percentage` is not in the range of 0-1.
        """
        if not 0.0 < self.validation_percentage < 1.0:
            raise ValueError("The validation_percentage should be 0-1.")
```
Check the sanity of the arguments.
Raises:
- ValueError: when the `validation_percentage` is not in the range of 0-1.
```python
    @abstractmethod
    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): The task ID to query CL class map.

        **Returns:**
        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """
```
The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.

Args:
- task_id (`int`): The task ID to query CL class map.

Returns:
- The CL class map of the task. Key is original class label, value is integer class label for continual learning.
  - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
  - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
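For illustration, assuming a task whose original labels are 0-9 and which appears second in the task sequence, the two paradigms would produce class maps roughly like this:

```python
# Illustrative only: a 10-class task appearing second in the task sequence.
til_map = {i: i for i in range(10)}       # TIL: {0: 0, 1: 1, ..., 9: 9}
cil_map = {i: i + 10 for i in range(10)}  # CIL: {0: 10, 1: 11, ..., 9: 19}
```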
```python
    @abstractmethod
    def prepare_data(self) -> None:
        """Use this to download and prepare data. It must be implemented by subclasses, regulated by `LightningDataModule`."""
```
Use this to download and prepare data. It must be implemented by subclasses, regulated by `LightningDataModule`.
```python
    def setup(self, stage: str) -> None:
        """Set up the dataset for different stages.

        **Args:**
        - **stage** (`str`): the stage of the experiment. Should be one of the following:
            - 'fit' or 'validate': training and validation dataset of current task `self.task_id` should be assigned to `self.dataset_train` and `self.dataset_val`.
            - 'test': a list of test dataset of all seen tasks (from task 0 to `self.task_id`) should be assigned to `self.dataset_test`.
        """
        if stage == "fit" or stage == "validate":  # both stages need the train/val datasets

            pylogger.debug(
                "Construct train and validation dataset for task %d...", self.task_id
            )
            self.dataset_train, self.dataset_val = self.train_and_val_dataset()
            pylogger.debug(
                "Train and validation dataset for task %d are ready.", self.task_id
            )

        if stage == "test":

            pylogger.debug("Construct test dataset for task %d...", self.task_id)
            self.dataset_test[self.task_id] = self.test_dataset()
            pylogger.debug("Test dataset for task %d is ready.", self.task_id)
```
Set up the dataset for different stages.
Args:
- stage (`str`): the stage of the experiment. Should be one of the following:
  - 'fit' or 'validate': training and validation dataset of current task `self.task_id` should be assigned to `self.dataset_train` and `self.dataset_val`.
  - 'test': a list of test dataset of all seen tasks (from task 0 to `self.task_id`) should be assigned to `self.dataset_test`.
```python
    def setup_task_id(self, task_id: int) -> None:
        """Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """
        self.task_id = task_id

        self.cl_class_map_t = self.cl_class_map(task_id)
        self.cl_class_mapping_t = CLClassMapping(self.cl_class_map_t)
```
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

Args:
- task_id (`int`): the target task ID.
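For example, the expected call order in a task loop is roughly the following sketch, where `cl_dataset` is any `CLDataset` subclass instance:

```python
cl_dataset.set_cl_paradigm("TIL")  # the paradigm must be known before the class map is built
cl_dataset.setup_task_id(2)        # select task 2 and build its CL class mapping
cl_dataset.setup("fit")            # only now are the task-2 train/val datasets constructed
```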
```python
    def set_cl_paradigm(self, cl_paradigm: str) -> None:
        """Set the continual learning paradigm to `self.cl_paradigm`. It is used to define the CL class map.

        **Args:**
        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
        """
        self.cl_paradigm = cl_paradigm
```
Set the continual learning paradigm to `self.cl_paradigm`. It is used to define the CL class map.

Args:
- cl_paradigm (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
```python
    @abstractmethod
    def mean(self, task_id: int) -> tuple[float]:
        """The mean values for normalisation of task `task_id`. Used when constructing the dataset.

        **Returns:**
        - The mean values for normalisation.
        """
```
The mean values for normalisation of task `task_id`. Used when constructing the dataset.

Returns:
- The mean values for normalisation.
```python
    @abstractmethod
    def std(self, task_id: int) -> tuple[float]:
        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset.

        **Returns:**
        - The standard deviation values for normalisation.
        """
```
The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset.

Returns:
- The standard deviation values for normalisation.
```python
    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
        """Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset.

        **Args:**
        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.

        **Returns:**
        - The composed training transforms.
        """
        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        transforms.ToTensor() if to_tensor else None,
                        self.custom_transforms,
                        transforms.Normalize(
                            self.mean(self.task_id), self.std(self.task_id)
                        ),
                    ],
                )
            )
        )  # the order of transforms matters
```
Transforms generator for train and validation dataset, incorporating the custom transforms with basic transforms like normalisation and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset.

Args:
- to_tensor (`bool`): whether to include the `ToTensor()` transform.

Returns:
- The composed training transforms.
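As a sketch of how a subclass might use this helper when building its datasets: the torchvision `MNIST` dataset and the manual `random_split` below are illustrative assumptions, not the actual clarena implementation.

```python
from torch.utils.data import random_split
from torchvision.datasets import MNIST

from clarena.cl_datasets import CLDataset


class MyMNISTDataset(CLDataset):
    # Sketch only: shows where the composed transforms plug into a torchvision
    # dataset. The other abstract methods of CLDataset are omitted here.
    def train_and_val_dataset(self):
        dataset_train_and_val = MNIST(
            root=self.root,
            train=True,
            transform=self.train_and_val_transforms(to_tensor=True),
            target_transform=self.target_transforms(),
            download=True,
        )
        # split off the validation portion
        n_val = int(len(dataset_train_and_val) * self.validation_percentage)
        n_train = len(dataset_train_and_val) - n_val
        return random_split(dataset_train_and_val, lengths=[n_train, n_val])
```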
```python
    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
        """Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.

        **Args:**
        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.

        **Returns:**
        - The composed test transforms.
        """
        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        transforms.ToTensor() if to_tensor else None,
                        transforms.Normalize(
                            self.mean(self.task_id), self.std(self.task_id)
                        ),
                    ],
                )
            )
        )  # the order of transforms matters
```
Transforms generator for test dataset. Only basic transforms like normalisation and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.

Args:
- to_tensor (`bool`): whether to include the `ToTensor()` transform.

Returns:
- The composed test transforms.
```python
    def target_transforms(self) -> transforms.Compose:
        """The target transform for the dataset, composing the custom target transforms with the CL class mapping. It is a handy tool to use in subclasses when constructing the dataset.

        **Returns:**
        - The composed target transforms.
        """
        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        self.custom_target_transforms,
                        self.cl_class_mapping_t,
                    ],
                )
            )
        )  # the order of transforms matters
```
The target transform for the dataset, composing the custom target transforms with the CL class mapping. It is a handy tool to use in subclasses when constructing the dataset.

Returns:
- The composed target transforms.
```python
    @abstractmethod
    def train_and_val_dataset(self) -> object:
        """Get the training and validation dataset of task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - The train and validation dataset of task `self.task_id`.
        """
```
Get the training and validation dataset of task `self.task_id`. It must be implemented by subclasses.

Returns:
- The train and validation dataset of task `self.task_id`.
```python
    @abstractmethod
    def test_dataset(self) -> object:
        """Get the test dataset of task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - The test dataset of task `self.task_id`.
        """
```
Get the test dataset of task `self.task_id`. It must be implemented by subclasses.

Returns:
- The test dataset of task `self.task_id`.
```python
    def train_dataloader(self) -> DataLoader:
        """DataLoader generator for stage train of task `self.task_id`. It is automatically called before training.

        **Returns:**
        - The train DataLoader of task `self.task_id`.
        """
        pylogger.debug("Construct train dataloader for task %d...", self.task_id)

        return DataLoader(
            dataset=self.dataset_train,
            batch_size=self.batch_size,
            shuffle=True,  # shuffle train batch to prevent overfitting
            num_workers=self.num_workers,
            persistent_workers=True,
        )
```
DataLoader generator for stage train of task `self.task_id`. It is automatically called before training.

Returns:
- The train DataLoader of task `self.task_id`.
```python
    def val_dataloader(self) -> DataLoader:
        """DataLoader generator for stage validate. It is automatically called before validating.

        **Returns:**
        - The validation DataLoader of task `self.task_id`.
        """
        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)

        return DataLoader(
            dataset=self.dataset_val,
            batch_size=self.batch_size,
            shuffle=False,  # don't have to shuffle val or test batch
            num_workers=self.num_workers,
            persistent_workers=True,
        )
```
DataLoader generator for stage validate. It is automatically called before validating.
Returns:
- The validation DataLoader of task `self.task_id`.
```python
    def test_dataloader(self) -> dict[int, DataLoader]:
        """DataLoader generator for stage test. It is automatically called before testing.

        **Returns:**
        - The test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Key is task_id, value is the DataLoader.
        """
        pylogger.debug("Construct test dataloader for task %d...", self.task_id)

        return {
            task_id: DataLoader(
                dataset=dataset_test,
                batch_size=self.batch_size,
                shuffle=False,  # don't have to shuffle val or test batch
                num_workers=self.num_workers,
                persistent_workers=True,  # speed up the dataloader worker initialization
            )
            for task_id, dataset_test in self.dataset_test.items()
        }
```
DataLoader generator for stage test. It is automatically called before testing.
Returns:
- The test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Key is task_id, value is the DataLoader.
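For illustration, the returned dict can be looped over to evaluate every seen task. The `model` and `evaluate` names below are hypothetical placeholders, not part of clarena.

```python
# Illustrative only: `cl_dataset`, `model` and `evaluate` are placeholders.
test_loaders = cl_dataset.test_dataloader()
for seen_task_id, test_loader in test_loaders.items():
    accuracy = evaluate(model, test_loader)  # hypothetical evaluation helper
    print(f"Task {seen_task_id}: test accuracy {accuracy:.4f}")
```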
```python
class CLPermutedDataset(CLDataset):
    """The base class of continual learning datasets which are constructed as permutations from an original dataset, inherited from `CLDataset`."""

    num_classes: int
    """The number of classes in the original dataset before permutation. It must be provided in subclasses."""

    img_size: torch.Size
    """The size of images in the original dataset before permutation. Used when constructing permutation operations. It must be provided in subclasses."""

    mean_original: tuple[float]
    """The mean values for normalisation. It must be provided in subclasses."""

    std_original: tuple[float]
    """The standard deviation values for normalisation. It must be provided in subclasses."""
```
The base class of continual learning datasets which are constructed as permutations from an original dataset, inherited from `CLDataset`.
```python
    def __init__(
        self,
        root: str,
        num_tasks: int,
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 10,
        custom_transforms: Callable | transforms.Compose | None = None,
        custom_target_transforms: Callable | transforms.Compose | None = None,
        permutation_mode: str = "first_channel_only",
        permutation_seeds: list[int] | None = None,
    ):
        """Initialise the CL dataset object providing the root where data files live.

        **Args:**
        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
            1. 'all': permute all pixels.
            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
            3. 'first_channel_only': permute only the first channel.
        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 0 to `num_tasks - 1`.
        """
        self.permutation_mode: str = permutation_mode
        """Store the mode of permutation. Used when constructing the permutation operations for tasks."""

        self.permutation_seeds: list[int] = (
            permutation_seeds if permutation_seeds else list(range(num_tasks))
        )
        """Store the permutation seeds for all tasks. Used when constructing the permutation operations for tasks."""

        self.permutation_seed_t: int
        """Store the permutation seed for the current task `self.task_id`."""
        self.permute_t: Permute
        """Store the permutation transform for the current task `self.task_id`."""

        super().__init__(
            root=root,
            num_tasks=num_tasks,
            validation_percentage=validation_percentage,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            custom_target_transforms=custom_target_transforms,
        )
```
Initialise the CL dataset object providing the root where data files live.
Args:
- root (`str`): the root directory where the original data files for constructing the CL dataset physically live.
- num_tasks (`int`): the maximum number of tasks supported by the CL dataset.
- validation_percentage (`float`): the percentage to randomly split some of the training data into validation data.
- batch_size (`int`): the batch size in train, val, test dataloader.
- num_workers (`int`): the number of workers for dataloaders.
- custom_transforms (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
- custom_target_transforms (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
- permutation_mode (`str`): the mode of permutation, should be one of the following:
  1. 'all': permute all pixels.
  2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
  3. 'first_channel_only': permute only the first channel.
- permutation_seeds (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 0 to `num_tasks - 1`.
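For example, a reproducible permuted task sequence could be configured roughly as follows. Values are illustrative, and `PermutedMNIST` is assumed to forward these arguments to `CLPermutedDataset`.

```python
from clarena.cl_datasets import PermutedMNIST

# Illustrative only: fixed seeds give a reproducible 5-task permuted sequence.
cl_dataset = PermutedMNIST(
    root="data/",
    num_tasks=5,
    validation_percentage=0.1,
    batch_size=64,
    permutation_mode="first_channel_only",
    permutation_seeds=[11, 22, 33, 44, 55],  # one seed per task
)
```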
- `num_classes`: The number of classes in the original dataset before permutation. It must be provided in subclasses.
- `img_size`: The size of images in the original dataset before permutation. Used when constructing permutation operations. It must be provided in subclasses.
- `mean_original`: The mean values for normalisation. It must be provided in subclasses.
- `std_original`: The standard deviation values for normalisation. It must be provided in subclasses.
- `permutation_mode`: Store the mode of permutation. Used when constructing the permutation operations for tasks.
- `permutation_seeds`: Store the permutation seeds for all tasks. Used when constructing the permutation operations for tasks.
```python
    def sanity_check(self) -> None:
        """Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: when the number of `permutation_seeds` is not equal to `num_tasks`.
        """
        if self.permutation_seeds and self.num_tasks != len(self.permutation_seeds):
            raise ValueError(
                "The number of permutation seeds is not equal to number of tasks!"
            )

        super().sanity_check()
```
Check the sanity of the arguments.
Raises:
- ValueError: when the number of `permutation_seeds` is not equal to `num_tasks`.
```python
    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): The task ID to query CL class map.

        **Returns:**
        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """
        if self.cl_paradigm == "TIL":
            return {i: i for i in range(self.num_classes)}
        if self.cl_paradigm == "CIL":
            return {
                i: i + (task_id - 1) * self.num_classes for i in range(self.num_classes)
            }
```
The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

Args:
- task_id (`int`): The task ID to query CL class map.

Returns:
- The CL class map of the task. Key is original class label, value is integer class label for continual learning.
  - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
  - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
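As a quick illustration of the `i + (task_id - 1) * num_classes` formula used in the CIL branch above, with a 10-class original dataset the third task maps as follows:

```python
# Illustrative only: a 10-class original dataset (e.g. MNIST), task_id = 3, CIL paradigm.
num_classes = 10
cil_map_task3 = {i: i + (3 - 1) * num_classes for i in range(num_classes)}
# -> {0: 20, 1: 21, ..., 9: 29}
```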
```python
    def setup_task_id(self, task_id: int) -> None:
        """Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """
        super().setup_task_id(task_id)

        self.permutation_seed_t = self.permutation_seeds[task_id - 1]
        self.permute_t = Permute(
            img_size=self.img_size,
            mode=self.permutation_mode,
            seed=self.permutation_seed_t,
        )
```
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

Args:
- task_id (`int`): the target task ID.
```python
    def mean(self, task_id: int) -> tuple[float]:
        """The mean values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.

        **Returns:**
        - The mean values for normalisation.
        """
        return self.mean_original
```
The mean values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.

Returns:
- The mean values for normalisation.
```python
    def std(self, task_id: int) -> tuple[float]:
        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the standard deviation values are the same as the original dataset.

        **Returns:**
        - The standard deviation values for normalisation.
        """
        return self.std_original
```
The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the standard deviation values are the same as the original dataset.

Returns:
- The standard deviation values for normalisation.
```python
    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
        """Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. In permuted CL datasets, permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.

        **Args:**
        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.

        **Returns:**
        - The composed training transforms.
        """
        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        transforms.ToTensor() if to_tensor else None,
                        self.permute_t,
                        self.custom_transforms,
                        transforms.Normalize(
                            self.mean(self.task_id), self.std(self.task_id)
                        ),
                    ],
                )
            )
        )  # the order of transforms matters
```
Transforms generator for train and validation dataset, incorporating the custom transforms with basic transforms like normalisation and `ToTensor()`. In permuted CL datasets, the permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.

Args:
- to_tensor (`bool`): whether to include the `ToTensor()` transform.

Returns:
- The composed training transforms.
```python
    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
        """Transforms generator for test dataset. Basic transforms like `normalisation` and `ToTensor()` as well as the permute transform are included. It is a handy tool to use in subclasses when constructing the dataset.

        **Args:**
        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.

        **Returns:**
        - The composed test transforms.
        """
        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        transforms.ToTensor() if to_tensor else None,
                        self.permute_t,
                        transforms.Normalize(
                            self.mean(self.task_id), self.std(self.task_id)
                        ),
                    ],
                )
            )
        )  # the order of transforms matters
```
Transforms generator for test dataset. Basic transforms like normalisation and `ToTensor()` as well as the permute transform are included. It is a handy tool to use in subclasses when constructing the dataset.

Args:
- to_tensor (`bool`): whether to include the `ToTensor()` transform.

Returns:
- The composed test transforms.
```python
class CLSplitDataset(CLDataset):
    """The base class of continual learning datasets which are constructed as splits of an original dataset, inherited from `CLDataset`."""

    num_classes: int
    """The number of classes in the original dataset. It must be provided in subclasses."""

    mean_original: tuple[float]
    """The mean values for normalisation. It must be provided in subclasses."""

    std_original: tuple[float]
    """The standard deviation values for normalisation. It must be provided in subclasses."""

    def mean(self, task_id: int) -> tuple[float]:
        """The mean values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the mean values are the same as the original dataset.

        **Returns:**
        - The mean values for normalisation.
        """
        return self.mean_original

    def std(self, task_id: int) -> tuple[float]:
        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the standard deviation values are the same as the original dataset.

        **Returns:**
        - The standard deviation values for normalisation.
        """
        return self.std_original

    def get_class_subset(self, dataset: Dataset, classes: list[int]) -> Dataset:
        """Provide a util method here to retrieve a subset of the given classes from a PyTorch Dataset. It could be useful when constructing the split CL dataset.

        **Args:**
        - **dataset** (`Dataset`): the original dataset to retrieve the subset from.
        - **classes** (`list[int]`): the classes of the subset.

        **Returns:**
        - `Dataset`: subset of original dataset in classes.
        """
        # get the indices of the dataset that belong to the classes
        idx = np.isin(dataset.targets, classes)

        # subset the dataset by the indices, in-place operation
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]

        return dataset
```
The base class of continual learning datasets which are constructed as splits of an original dataset, inherited from `CLDataset`.
```python
    def __init__(
        self,
        root: str,
        num_tasks: int,
        class_split: list[list[int]],
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 10,
        custom_transforms: Callable | transforms.Compose | None = None,
        custom_target_transforms: Callable | transforms.Compose | None = None,
    ):
        """Initialise the CL dataset object providing the root where data files live.

        **Args:**
        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
        - **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
        """
        self.class_split = class_split
        """Store the class split for each task. Used when constructing the split dataset."""

        super().__init__(
            root=root,
            num_tasks=num_tasks,
            validation_percentage=validation_percentage,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            custom_target_transforms=custom_target_transforms,
        )
```
Initialise the CL dataset object providing the root where data files live.

Args:
- root (`str`): the root directory where the original data files for constructing the CL dataset physically live.
- num_tasks (`int`): the maximum number of tasks supported by the CL dataset.
- class_split (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
- validation_percentage (`float`): the percentage to randomly split some of the training data into validation data.
- batch_size (`int`): The batch size in train, val, test dataloader.
- num_workers (`int`): the number of workers for dataloaders.
- custom_transforms (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
- custom_target_transforms (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
- `num_classes`: The number of classes in the original dataset. It must be provided in subclasses.
- `std_original`: The standard deviation values for normalisation. It must be provided in subclasses.
```python
    def sanity_check(self) -> None:
        """Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: when the length of `class_split` is not equal to `num_tasks`.
        - **ValueError**: when any of the lists in `class_split` has less than 2 elements. A classification task must have at least 2 classes.
        """
        if len(self.class_split) != self.num_tasks:
            raise ValueError(
                "The length of class split is not equal to number of tasks!"
            )
        if any(len(split) < 2 for split in self.class_split):
            raise ValueError("Each class split must contain at least 2 elements!")

        super().sanity_check()
```
Check the sanity of the arguments.

Raises:
- ValueError: when the length of `class_split` is not equal to `num_tasks`.
- ValueError: when any of the lists in `class_split` has less than 2 elements. A classification task must have at least 2 classes.
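For instance, a valid `class_split` configuration must line up with `num_tasks` and give every task at least two classes; the values below are purely illustrative:

```python
# A hypothetical class split: 10 original classes divided into 5 tasks of 2 classes each.
num_tasks = 5
class_split = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

# These are exactly the constraints that `sanity_check` enforces.
assert len(class_split) == num_tasks
assert all(len(split) >= 2 for split in class_split)
```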
```python
    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): The task ID to query CL class map.

        **Returns:**
        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """
        num_classes_t = len(self.class_split[task_id - 1])
        if self.cl_paradigm == "TIL":
            return {self.class_split[task_id - 1][i]: i for i in range(num_classes_t)}
        if self.cl_paradigm == "CIL":
            return {
                self.class_split[task_id - 1][i]: i + (task_id - 1) * self.num_classes
                for i in range(num_classes_t)
            }
```
The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

Args:
- task_id (`int`): The task ID to query CL class map.

Returns:
- The CL class map of the task. Key is original class label, value is integer class label for continual learning.
  - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
  - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
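To make the TIL/CIL difference concrete, here is a minimal standalone sketch that mirrors the mapping logic above; the `class_split` values, and the assumption that `num_classes` counts the classes per task, are hypothetical and for illustration only:

```python
# Hypothetical two-task split, each task containing 2 of the original classes.
class_split = [[0, 1], [2, 3]]
num_classes = 2  # assumed here to be the number of classes per task


def cl_class_map(task_id: int, cl_paradigm: str) -> dict[int, int]:
    """Mirror of the mapping logic above (task_id is 1-indexed)."""
    split_t = class_split[task_id - 1]
    if cl_paradigm == "TIL":
        return {c: i for i, c in enumerate(split_t)}
    if cl_paradigm == "CIL":
        return {c: i + (task_id - 1) * num_classes for i, c in enumerate(split_t)}


print(cl_class_map(2, "TIL"))  # {2: 0, 3: 1}: labels restart from 0 for every task
print(cl_class_map(2, "CIL"))  # {2: 2, 3: 3}: labels continue after previous tasks' classes
```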
```python
    def mean(self, task_id: int) -> tuple[float]:
        """The mean values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the mean values are the same as the original dataset.

        **Returns:**
        - The mean values for normalisation.
        """
        return self.mean_original
```
The mean values for normalisation of task `task_id`. Used when constructing the dataset. In a split CL dataset, the mean values are the same as in the original dataset.

Returns:
- The mean values for normalisation.
```python
    def std(self, task_id: int) -> tuple[float]:
        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the standard deviation values are the same as the original dataset.

        **Returns:**
        - The standard deviation values for normalisation.
        """
        return self.std_original
```
The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In a split CL dataset, the standard deviation values are the same as in the original dataset.

Returns:
- The standard deviation values for normalisation.
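As a rough illustration of where these statistics are consumed, a per-task transform pipeline could plug them into a standard normalisation step (the statistics below are placeholder values, not those of any real dataset):

```python
from torchvision import transforms

# Placeholder per-channel statistics; a split CL dataset would return the
# original dataset's statistics from `mean(task_id)` and `std(task_id)`.
mean_t = (0.5, 0.5, 0.5)
std_t = (0.5, 0.5, 0.5)

normalise = transforms.Normalize(mean=mean_t, std=std_t)
```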
```python
    def get_class_subset(self, dataset: Dataset, classes: list[int]) -> Dataset:
        """A utility method to retrieve the subset of a PyTorch Dataset that belongs to the given classes. It can be useful when constructing the split CL dataset.

        **Args:**
        - **dataset** (`Dataset`): the original dataset to retrieve the subset from.
        - **classes** (`list[int]`): the classes of the subset.

        **Returns:**
        - `Dataset`: subset of the original dataset in the given classes.
        """
        # get the indices of the samples that belong to the given classes
        idx = np.isin(dataset.targets, classes)

        # subset the dataset by the indices (in-place operation)
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]

        return dataset
```
A utility method to retrieve the subset of a PyTorch Dataset that belongs to the given classes. It can be useful when constructing the split CL dataset.

Args:
- dataset (`Dataset`): the original dataset to retrieve the subset from.
- classes (`list[int]`): the classes of the subset.

Returns:
- `Dataset`: subset of the original dataset in the given classes.
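The filtering itself is a plain boolean mask over `targets`. A small self-contained sketch of the same idea, using a toy dataset-like object rather than a real torchvision dataset:

```python
import numpy as np
import torch


class ToyDataset:
    """A minimal stand-in with `.data` and `.targets`, mimicking torchvision-style datasets."""

    def __init__(self):
        self.data = torch.arange(12).reshape(6, 2)       # 6 samples, 2 features each
        self.targets = torch.tensor([0, 1, 2, 0, 1, 2])  # class labels


dataset = ToyDataset()
classes = [0, 2]

# Boolean mask of samples whose label is in `classes`, as in `get_class_subset`.
idx = torch.from_numpy(np.isin(dataset.targets.numpy(), classes))
dataset.data = dataset.data[idx]
dataset.targets = dataset.targets[idx]

print(dataset.targets)  # tensor([0, 2, 0, 2])
```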
```python
class CLClassMapping:
    """CL Class mapping to dataset labels. Used as a PyTorch target Transform."""

    def __init__(self, cl_class_map: dict[str | int, int]):
        """Initialise the CL class mapping transform object from the CL class map of a task.

        **Args:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map for a task.
        """
        self.cl_class_map = cl_class_map

    def __call__(self, target: torch.Tensor) -> torch.Tensor:
        """The CL class mapping transform to dataset labels. It is defined as a callable object like a PyTorch Transform.

        **Args:**
        - **target** (`Tensor`): the target tensor.

        **Returns:**
        - The transformed target tensor.
        """

        return self.cl_class_map[target]
```
CL Class mapping to dataset labels. Used as a PyTorch target Transform.
```python
    def __init__(self, cl_class_map: dict[str | int, int]):
        """Initialise the CL class mapping transform object from the CL class map of a task.

        **Args:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map for a task.
        """
        self.cl_class_map = cl_class_map
```
Initialise the CL class mapping transform object from the CL class map of a task.

Args:
- cl_class_map (`dict[str | int, int]`): the CL class map for a task.
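A brief usage sketch (the class map values here are hypothetical, e.g. as produced by `cl_class_map` for some task):

```python
from clarena.cl_datasets import CLClassMapping

# Map original labels 2 and 3 to continual learning labels 0 and 1 (TIL-style).
cl_class_mapping_t = CLClassMapping({2: 0, 3: 1})

print(cl_class_mapping_t(2))  # 0
print(cl_class_mapping_t(3))  # 1

# It can then be composed with other target transforms, e.g. in
# transforms.Compose([...]), and passed to a dataset as its target transform.
```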
```python
class Permute:
    """Permutation operation to image. Used to construct permuted CL dataset.

    Used as a PyTorch Dataset Transform.
    """

    def __init__(
        self,
        img_size: torch.Size,
        mode: str = "first_channel_only",
        seed: int | None = None,
    ):
        """Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.

        **Args:**
        - **img_size** (`torch.Size`): the size of the image to be permuted.
        - **mode** (`str`): the mode of permutation, should be one of the following:
            - 'all': permute all pixels.
            - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
            - 'first_channel_only': permute only the first channel.
        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
        """
        self.mode = mode
        """Store the mode of permutation."""

        # get generator for permutation
        torch_generator = torch.Generator()
        if seed:
            torch_generator.manual_seed(seed)

        # calculate the number of pixels from the image size
        if self.mode == "all":
            num_pixels = img_size[0] * img_size[1] * img_size[2]
        elif self.mode in ("by_channel", "first_channel_only"):
            num_pixels = img_size[1] * img_size[2]

        self.permute: torch.Tensor = torch.randperm(
            num_pixels, generator=torch_generator
        )
        """The permutation order, a `Tensor` permuted from [1, 2, ..., `num_pixels`] with the given seed. It is the core element of permutation operation."""

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """The permutation operation to image is defined as a callable object like a PyTorch Transform.

        **Args:**
        - **img** (`Tensor`): image to be permuted. Must match the size of `img_size` in the initialisation.

        **Returns:**
        - The permuted image (`Tensor`).
        """

        if self.mode == "all":
            # flatten the whole image to 1D so the 1D permutation order can be applied
            img_flat = img.view(-1)
            img_flat_permuted = img_flat[self.permute]  # conduct permutation operation
            img_permuted = img_flat_permuted.view(img.size())  # return to the original image shape
            return img_permuted

        if self.mode == "by_channel":
            permuted_channels = []
            for i in range(img.size(0)):
                # act on every channel with the same permutation order
                channel_flat = img[i].view(-1)  # flatten the channel to 1D
                channel_flat_permuted = channel_flat[self.permute]  # conduct permutation operation
                channel_permuted = channel_flat_permuted.view(img[0].size())  # return to the original channel shape
                permuted_channels.append(channel_permuted)
            img_permuted = torch.stack(permuted_channels)  # stack the permuted channels to restore the image
            return img_permuted

        if self.mode == "first_channel_only":
            first_channel_flat = img[0].view(-1)  # flatten the first channel to 1D
            first_channel_flat_permuted = first_channel_flat[self.permute]  # conduct permutation operation
            first_channel_permuted = first_channel_flat_permuted.view(img[0].size())  # return to the original channel shape

            img_permuted = img.clone()
            img_permuted[0] = first_channel_permuted

            return img_permuted
```
Permutation operation to image. Used to construct permuted CL dataset.
Used as a PyTorch Dataset Transform.
```python
    def __init__(
        self,
        img_size: torch.Size,
        mode: str = "first_channel_only",
        seed: int | None = None,
    ):
        """Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.

        **Args:**
        - **img_size** (`torch.Size`): the size of the image to be permuted.
        - **mode** (`str`): the mode of permutation, should be one of the following:
            - 'all': permute all pixels.
            - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
            - 'first_channel_only': permute only the first channel.
        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
        """
        self.mode = mode
        """Store the mode of permutation."""

        # get generator for permutation
        torch_generator = torch.Generator()
        if seed:
            torch_generator.manual_seed(seed)

        # calculate the number of pixels from the image size
        if self.mode == "all":
            num_pixels = img_size[0] * img_size[1] * img_size[2]
        elif self.mode in ("by_channel", "first_channel_only"):
            num_pixels = img_size[1] * img_size[2]

        self.permute: torch.Tensor = torch.randperm(
            num_pixels, generator=torch_generator
        )
        """The permutation order, a `Tensor` permuted from [1, 2, ..., `num_pixels`] with the given seed. It is the core element of permutation operation."""
```
Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.

Args:
- img_size (`torch.Size`): the size of the image to be permuted.
- mode (`str`): the mode of permutation, should be one of the following:
  - 'all': permute all pixels.
  - 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
  - 'first_channel_only': permute only the first channel.
- seed (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from the PyTorch generator.
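As a brief usage sketch (the image tensor and seed below are arbitrary), a `Permute` transform is constructed once and then applied like any other PyTorch transform:

```python
import torch

from clarena.cl_datasets import Permute

# A dummy 1x28x28 image (MNIST-sized); the values are illustrative only.
img = torch.rand(1, 28, 28)

# Fixing the seed makes the permutation order reproducible across runs.
permute = Permute(img_size=img.size(), mode="first_channel_only", seed=42)

permuted_img = permute(img)
print(permuted_img.shape)  # torch.Size([1, 28, 28]): same shape, pixels reordered
```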