clarena.cl_datasets
Continual Learning Datasets
This submodule provides the continual learning datasets that can be used in CLArena.
Here are the base classes for continual learning datasets, which inherit from Lightning `LightningDataModule`:
- `CLDataset`: The base class for all continual learning datasets.
    - `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
    - `CLRotatedDataset`: The base class for rotated continual learning datasets. A child class of `CLDataset`.
    - `CLSplitDataset`: The base class for split continual learning datasets. A child class of `CLDataset`.
    - `CLCombinedDataset`: The base class for combined continual learning datasets. A child class of `CLDataset`.
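The hierarchy above can be sketched as a minimal Python skeleton. This is illustrative only: the real base classes live in `clarena.cl_datasets.base` and inherit from Lightning's `LightningDataModule`, which is stubbed out here so the sketch runs without any dependencies.

```python
# Illustrative skeleton of the base-class hierarchy (not the real implementation).
class LightningDataModule:  # stand-in for lightning.LightningDataModule
    pass


class CLDataset(LightningDataModule):
    """Base class for all continual learning datasets."""


class CLPermutedDataset(CLDataset):
    """Base class for permuted continual learning datasets."""


class CLRotatedDataset(CLDataset):
    """Base class for rotated continual learning datasets."""


class CLSplitDataset(CLDataset):
    """Base class for split continual learning datasets."""


class CLCombinedDataset(CLDataset):
    """Base class for combined continual learning datasets."""


# All four variants are children of CLDataset:
print(issubclass(CLPermutedDataset, CLDataset))  # → True
```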
Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement continual learning datasets:

- [**Configure CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/CL-dataset)
- [**Implement Custom CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl_dataset)
- [**A Beginners' Guide to Continual Learning (CL Dataset)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset)
```python
r"""

# Continual Learning Datasets

This submodule provides the **continual learning datasets** that can be used in CLArena.

Here are the base classes for continual learning datasets, which inherit from Lightning `LightningDataModule`:

- `CLDataset`: The base class for all continual learning datasets.
    - `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
    - `CLRotatedDataset`: The base class for rotated continual learning datasets. A child class of `CLDataset`.
    - `CLSplitDataset`: The base class for split continual learning datasets. A child class of `CLDataset`.
    - `CLCombinedDataset`: The base class for combined continual learning datasets. A child class of `CLDataset`.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement continual learning datasets:

- [**Configure CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/CL-dataset)
- [**Implement Custom CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl_dataset)
- [**A Beginners' Guide to Continual Learning (CL Dataset)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset)

"""

from .base import (
    CLDataset,
    CLPermutedDataset,
    CLRotatedDataset,
    CLSplitDataset,
    CLCombinedDataset,
)

from .permuted_mnist import PermutedMNIST
from .rotated_mnist import RotatedMNIST
from .permuted_emnist import PermutedEMNIST
from .permuted_fashionmnist import PermutedFashionMNIST
from .permuted_kmnist import PermutedKMNIST
from .permuted_notmnist import PermutedNotMNIST
from .permuted_sign_language_mnist import PermutedSignLanguageMNIST
from .permuted_ahdd import PermutedArabicHandwrittenDigits
from .permuted_kannadamnist import PermutedKannadaMNIST
from .permuted_svhn import PermutedSVHN
from .permuted_country211 import PermutedCountry211
from .permuted_imagenette import PermutedImagenette
from .permuted_dtd import PermutedDTD
from .permuted_cifar10 import PermutedCIFAR10
from .permuted_cifar100 import PermutedCIFAR100
from .permuted_caltech101 import PermutedCaltech101
from .permuted_caltech256 import PermutedCaltech256
from .permuted_eurosat import PermutedEuroSAT
from .permuted_fgvc_aircraft import PermutedFGVCAircraft
from .permuted_flowers102 import PermutedFlowers102
from .permuted_food101 import PermutedFood101
from .permuted_celeba import PermutedCelebA
from .permuted_fer2013 import PermutedFER2013
from .permuted_tinyimagenet import PermutedTinyImageNet
from .permuted_oxford_iiit_pet import PermutedOxfordIIITPet
from .permuted_pcam import PermutedPCAM
from .permuted_renderedsst2 import PermutedRenderedSST2
from .permuted_stanfordcars import PermutedStanfordCars
from .permuted_sun397 import PermutedSUN397
from .permuted_usps import PermutedUSPS
from .permuted_SEMEION import PermutedSEMEION
from .permuted_facescrub import PermutedFaceScrub
from .permuted_cub2002011 import PermutedCUB2002011
from .permuted_gtsrb import PermutedGTSRB
from .permuted_linnaeus5 import PermutedLinnaeus5

from .split_cifar10 import SplitCIFAR10
from .split_mnist import SplitMNIST
from .split_cifar100 import SplitCIFAR100
from .split_tinyimagenet import SplitTinyImageNet
from .split_cub2002011 import SplitCUB2002011

from .combined import Combined


__all__ = [
    "CLDataset",
    "CLPermutedDataset",
    "CLRotatedDataset",
    "CLSplitDataset",
    "CLCombinedDataset",
    "combined",
    "permuted_mnist",
    "rotated_mnist",
    "permuted_emnist",
    "permuted_fashionmnist",
    "permuted_imagenette",
    "permuted_sign_language_mnist",
    "permuted_ahdd",
    "permuted_kannadamnist",
    "permuted_country211",
    "permuted_dtd",
    "permuted_fer2013",
    "permuted_fgvc_aircraft",
    "permuted_flowers102",
    "permuted_food101",
    "permuted_kmnist",
    "permuted_notmnist",
    "permuted_svhn",
    "permuted_cifar10",
    "permuted_cifar100",
    "permuted_caltech101",
    "permuted_caltech256",
    "permuted_oxford_iiit_pet",
    "permuted_celeba",
    "permuted_eurosat",
    "permuted_facescrub",
    "permuted_pcam",
    "permuted_renderedsst2",
    "permuted_stanfordcars",
    "permuted_sun397",
    "permuted_usps",
    "permuted_SEMEION",
    "permuted_tinyimagenet",
    "permuted_cub2002011",
    "permuted_gtsrb",
    "permuted_linnaeus5",
    "split_mnist",
    "split_cifar10",
    "split_cifar100",
    "split_tinyimagenet",
    "split_cub2002011",
]
```
```python
class CLDataset(LightningDataModule):
    r"""The base class of continual learning datasets."""

    def __init__(
        self,
        root: str | dict[int, str],
        num_tasks: int,
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs, from 1 to `num_tasks`.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
        If it is a dict, the keys are task IDs and the values are the numbers of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the numbers of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
        If it is a dict, the keys are task IDs and the values are the sizes to resize to for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        """
        super().__init__()

        self.root: dict[int, str] = (
            OmegaConf.to_container(root)
            if isinstance(root, DictConfig)
            else {t: root for t in range(1, num_tasks + 1)}
        )
        r"""The dict of root directories of the original data files for each task."""
        self.num_tasks: int = num_tasks
        r"""The maximum number of tasks supported by the dataset."""
        self.cl_paradigm: str
        r"""The continual learning paradigm."""
        self.batch_size: dict[int, int] = (
            OmegaConf.to_container(batch_size)
            if isinstance(batch_size, DictConfig)
            else {t: batch_size for t in range(1, num_tasks + 1)}
        )
        r"""The dict of batch sizes for each task."""
        self.num_workers: dict[int, int] = (
            OmegaConf.to_container(num_workers)
            if isinstance(num_workers, DictConfig)
            else {t: num_workers for t in range(1, num_tasks + 1)}
        )
        r"""The dict of numbers of workers for each task."""
        self.custom_transforms: dict[int, Callable | transforms.Compose | None] = (
            OmegaConf.to_container(custom_transforms)
            if isinstance(custom_transforms, DictConfig)
            else {t: custom_transforms for t in range(1, num_tasks + 1)}
        )
        r"""The dict of custom transforms for each task."""
        self.repeat_channels: dict[int, int | None] = (
            OmegaConf.to_container(repeat_channels)
            if isinstance(repeat_channels, DictConfig)
            else {t: repeat_channels for t in range(1, num_tasks + 1)}
        )
        r"""The dict of numbers of channels to repeat for each task."""
        self.to_tensor: dict[int, bool] = (
            OmegaConf.to_container(to_tensor)
            if isinstance(to_tensor, DictConfig)
            else {t: to_tensor for t in range(1, num_tasks + 1)}
        )
        r"""The dict of to_tensor flags for each task."""
        self.resize: dict[int, tuple[int, int] | None] = (
            {t: tuple(rs) if rs else None for t, rs in resize.items()}
            if isinstance(resize, DictConfig)
            else {
                t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1)
            }
        )
        r"""The dict of sizes to resize to for each task."""

        # task-specific attributes
        self.root_t: str
        r"""The root directory of the original data files for the current task `self.task_id`."""
        self.batch_size_t: int
        r"""The batch size for the current task `self.task_id`."""
        self.num_workers_t: int
        r"""The number of workers for the current task `self.task_id`."""
        self.custom_transforms_t: Callable | transforms.Compose | None
        r"""The custom transforms for the current task `self.task_id`."""
        self.repeat_channels_t: int | None
        r"""The number of channels to repeat for the current task `self.task_id`."""
        self.to_tensor_t: bool
        r"""The to_tensor flag for the current task `self.task_id`."""
        self.resize_t: tuple[int, int] | None
        r"""The size to resize to for the current task `self.task_id`."""
        self.mean_t: float
        r"""The mean values for normalization for the current task `self.task_id`."""
        self.std_t: float
        r"""The standard deviation values for normalization for the current task `self.task_id`."""

        # dataset containers
        self.dataset_train_t: Any
        r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
        self.dataset_val_t: Any
        r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
        self.dataset_test: dict[int, Any] = {}
        r"""The dictionary storing the test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects."""

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

        CLDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check if each task has been provided with the necessary arguments
        for attr in [
            "root",
            "batch_size",
            "num_workers",
            "custom_transforms",
            "repeat_channels",
            "to_tensor",
            "resize",
        ]:
            value = getattr(self, attr)
            expected_keys = set(range(1, self.num_tasks + 1))
            if set(value.keys()) != expected_keys:
                raise ValueError(
                    f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
                )

    @abstractmethod
    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit the continual learning paradigm `self.cl_paradigm`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of each task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of each task should be continuous integers starting from the total number of classes in previous tasks.
        """

    @abstractmethod
    def prepare_data(self) -> None:
        r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`. This method is called at the beginning of each task."""

    def setup(self, stage: str) -> None:
        r"""Set up the dataset for different stages. This method is called at the beginning of each task.

        **Args:**
        - **stage** (`str`): the stage of the experiment; one of:
            - 'fit': training and validation datasets of the current task `self.task_id` are assigned to `self.dataset_train_t` and `self.dataset_val_t`.
            - 'test': a dict of test datasets of all seen tasks should be assigned to `self.dataset_test`.
        """
        if stage == "fit":
            # train and validation must be constructed together because a sanity check
            # for validation is conducted before training
            pylogger.debug(
                "Construct train and validation dataset for task %d...", self.task_id
            )

            self.dataset_train_t, self.dataset_val_t = self.train_and_val_dataset()

            pylogger.info(
                "Train and validation datasets for task %d are ready.", self.task_id
            )
            pylogger.info(
                "Train dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_train_t),
            )
            pylogger.info(
                "Validation dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_val_t),
            )

        elif stage == "test":
            pylogger.debug("Construct test dataset for task %d...", self.task_id)

            self.dataset_test[self.task_id] = self.test_dataset()

            pylogger.info("Test dataset for task %d is ready.", self.task_id)
            pylogger.info(
                "Test dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_test[self.task_id]),
            )

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        self.task_id = task_id

        self.root_t = self.root[task_id]
        self.batch_size_t = self.batch_size[task_id]
        self.num_workers_t = self.num_workers[task_id]
        self.custom_transforms_t = self.custom_transforms[task_id]
        self.repeat_channels_t = self.repeat_channels[task_id]
        self.to_tensor_t = self.to_tensor[task_id]
        self.resize_t = self.resize[task_id]

        self.processed_task_ids.append(task_id)

    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
        r"""Set up tasks for continual learning main evaluation.

        **Args:**
        - **eval_tasks** (`list[int]`): the list of task IDs to evaluate.
        """
        pylogger.debug("Setting up evaluation tasks: %s", eval_tasks)
        for task_id in eval_tasks:
            self.setup_task_id(task_id=task_id)
            self.setup(stage="test")

    def set_cl_paradigm(self, cl_paradigm: str) -> None:
        r"""Set `cl_paradigm` to `self.cl_paradigm`. It is used to define the CL class map.

        **Args:**
        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
        """
        self.cl_paradigm = cl_paradigm

    def train_and_val_transforms(self) -> transforms.Compose:
        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.custom_transforms_t,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

    def test_transforms(self) -> transforms.Compose:
        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. It is used in subclasses when constructing the dataset.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters; no custom transforms for test

    def target_transform(self) -> ClassMapping:
        r"""Target transform to map the original class labels to CL class labels according to `self.cl_paradigm`. It can be used in subclasses when constructing the dataset.

        **Returns:**
        - **target_transform** (`Callable`): the target transform function.
        """

        cl_class_map = self.get_cl_class_map(task_id=self.task_id)

        target_transform = ClassMapping(class_map=cl_class_map)

        return target_transform

    @abstractmethod
    def train_and_val_dataset(self) -> tuple[Any, Any]:
        r"""Get the training and validation datasets of the current task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation datasets of the current task `self.task_id`.
        """

    @abstractmethod
    def test_dataset(self) -> Any:
        r"""Get the test dataset of the current task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - **test_dataset** (`Any`): the test dataset of the current task `self.task_id`.
        """

    def train_dataloader(self) -> DataLoader:
        r"""DataLoader generator for the train stage of the current task `self.task_id`. It is automatically called before training the task.

        **Returns:**
        - **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
        """

        pylogger.debug("Construct train dataloader for task %d...", self.task_id)

        return DataLoader(
            dataset=self.dataset_train_t,
            batch_size=self.batch_size_t,
            shuffle=True,  # shuffle train batch to prevent overfitting
            num_workers=self.num_workers_t,
            drop_last=True,  # avoid batch norm errors when the last batch has size 1
        )

    def val_dataloader(self) -> DataLoader:
        r"""DataLoader generator for the validation stage of the current task `self.task_id`. It is automatically called before the task's validation.

        **Returns:**
        - **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
        """

        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)

        return DataLoader(
            dataset=self.dataset_val_t,
            batch_size=self.batch_size_t,
            shuffle=False,  # no need to shuffle val or test batches
            num_workers=self.num_workers_t,
        )

    def test_dataloader(self) -> dict[int, DataLoader]:
        r"""DataLoader generator for the test stage of the current task `self.task_id`. It is automatically called before testing the task.

        **Returns:**
        - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
        """

        pylogger.debug("Construct test dataloader for task %d...", self.task_id)

        return {
            task_id: DataLoader(
                dataset=dataset_test_t,
                batch_size=self.batch_size_t,
                shuffle=False,  # no need to shuffle val or test batches
                num_workers=self.num_workers_t,
            )
            for task_id, dataset_test_t in self.dataset_test.items()
        }

    def __len__(self) -> int:
        r"""Get the number of tasks in the dataset.

        **Returns:**
        - **num_tasks** (`int`): the number of tasks in the dataset.
        """
        return self.num_tasks
```
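Every scalar-or-dict constructor argument (e.g. `batch_size: int | dict[int, int]`) is normalized to a per-task dict keyed by task IDs 1 to `num_tasks`. A standalone sketch of that broadcasting pattern is below; the helper name `broadcast_per_task` is hypothetical (the real `__init__` inlines this logic together with OmegaConf `DictConfig` handling), but the dict-key validation mirrors `sanity_check`.

```python
from typing import Any


def broadcast_per_task(value: Any, num_tasks: int) -> dict[int, Any]:
    """Broadcast a scalar argument to every task, or validate a per-task dict.

    Mirrors the pattern in CLDataset.__init__: a plain value applies to all
    tasks; a dict must have keys exactly 1..num_tasks, as sanity_check enforces.
    Hypothetical helper for illustration only.
    """
    if isinstance(value, dict):
        expected_keys = set(range(1, num_tasks + 1))
        if set(value.keys()) != expected_keys:
            raise ValueError(
                "dict keys must be consecutive integers from 1 to num_tasks."
            )
        return dict(value)
    return {t: value for t in range(1, num_tasks + 1)}


# A scalar batch size applies to all tasks:
print(broadcast_per_task(32, num_tasks=3))  # → {1: 32, 2: 32, 3: 32}
# A complete per-task dict passes validation unchanged:
print(broadcast_per_task({1: 32, 2: 64, 3: 128}, num_tasks=3))
```

An incomplete dict (e.g. `{1: 32}` with `num_tasks=3`) raises `ValueError`, matching the check in `sanity_check`.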
The base class of continual learning datasets.
38 def __init__( 39 self, 40 root: str | dict[int, str], 41 num_tasks: int, 42 batch_size: int | dict[int, int] = 1, 43 num_workers: int | dict[int, int] = 0, 44 custom_transforms: ( 45 Callable 46 | transforms.Compose 47 | None 48 | dict[int, Callable | transforms.Compose | None] 49 ) = None, 50 repeat_channels: int | None | dict[int, int | None] = None, 51 to_tensor: bool | dict[int, bool] = True, 52 resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, 53 ) -> None: 54 r""" 55 **Args:** 56 - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live. 57 If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks. 58 - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`. 59 - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. 60 If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks. 61 - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. 62 If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks. 63 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included. 64 If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied. 
65 - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. 66 If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied. 67 - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. 68 If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks. 69 - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. 70 If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied. 
71 """ 72 super().__init__() 73 74 self.root: dict[int, str] = ( 75 OmegaConf.to_container(root) 76 if isinstance(root, DictConfig) 77 else {t: root for t in range(1, num_tasks + 1)} 78 ) 79 r"""The dict of root directories of the original data files for each task.""" 80 self.num_tasks: int = num_tasks 81 r"""The maximum number of tasks supported by the dataset.""" 82 self.cl_paradigm: str 83 r"""The continual learning paradigm.""" 84 self.batch_size: dict[int, int] = ( 85 OmegaConf.to_container(batch_size) 86 if isinstance(batch_size, DictConfig) 87 else {t: batch_size for t in range(1, num_tasks + 1)} 88 ) 89 r"""The dict of batch sizes for each task.""" 90 self.num_workers: dict[int, int] = ( 91 OmegaConf.to_container(num_workers) 92 if isinstance(num_workers, DictConfig) 93 else {t: num_workers for t in range(1, num_tasks + 1)} 94 ) 95 r"""The dict of numbers of workers for each task.""" 96 self.custom_transforms: dict[int, Callable | transforms.Compose | None] = ( 97 OmegaConf.to_container(custom_transforms) 98 if isinstance(custom_transforms, DictConfig) 99 else {t: custom_transforms for t in range(1, num_tasks + 1)} 100 ) 101 r"""The dict of custom transforms for each task.""" 102 self.repeat_channels: dict[int, int | None] = ( 103 OmegaConf.to_container(repeat_channels) 104 if isinstance(repeat_channels, DictConfig) 105 else {t: repeat_channels for t in range(1, num_tasks + 1)} 106 ) 107 r"""The dict of number of channels to repeat for each task.""" 108 self.to_tensor: dict[int, bool] = ( 109 OmegaConf.to_container(to_tensor) 110 if isinstance(to_tensor, DictConfig) 111 else {t: to_tensor for t in range(1, num_tasks + 1)} 112 ) 113 r"""The dict of to_tensor flag for each task. 
""" 114 self.resize: dict[int, tuple[int, int] | None] = ( 115 {t: tuple(rs) if rs else None for t, rs in resize.items()} 116 if isinstance(resize, DictConfig) 117 else { 118 t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1) 119 } 120 ) 121 r"""The dict of sizes to resize to for each task.""" 122 123 # task-specific attributes 124 self.root_t: str 125 r"""The root directory of the original data files for the current task `self.task_id`.""" 126 self.batch_size_t: int 127 r"""The batch size for the current task `self.task_id`.""" 128 self.num_workers_t: int 129 r"""The number of workers for the current task `self.task_id`.""" 130 self.custom_transforms_t: Callable | transforms.Compose | None 131 r"""The custom transforms for the current task `self.task_id`.""" 132 self.repeat_channels_t: int | None 133 r"""The number of channels to repeat for the current task `self.task_id`.""" 134 self.to_tensor_t: bool 135 r"""The to_tensor flag for the current task `self.task_id`.""" 136 self.resize_t: tuple[int, int] | None 137 r"""The size to resize for the current task `self.task_id`.""" 138 self.mean_t: float 139 r"""The mean values for normalization for the current task `self.task_id`.""" 140 self.std_t: float 141 r"""The standard deviation values for normalization for the current task `self.task_id`.""" 142 143 # dataset containers 144 self.dataset_train_t: Any 145 r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object.""" 146 self.dataset_val_t: Any 147 r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.""" 148 self.dataset_test: dict[int, Any] = {} 149 r"""The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.""" 150 151 # task ID control 152 self.task_id: int 153 r"""Task ID counter indicating which task is being processed. 
Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset.""" 154 self.processed_task_ids: list[int] = [] 155 r"""Task IDs that have been processed.""" 156 157 CLDataset.sanity_check(self)
Args:
- root (
str|dict[int, str]): the root directory where the original data files for constructing the CL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks. - num_tasks (
int): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 tonum_tasks. - batch_size (
int|dict[int, int]): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is anint, it is the same batch size for all tasks. - num_workers (
int|dict[int, int]): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is anint, it is the same number of workers for all tasks. - custom_transforms (
transformortransforms.ComposeorNoneor dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform.ToTensor(), normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it isNone, no custom transforms are applied. - repeat_channels (
int|None| dict of them): the number of channels to repeat for each task. Default isNone, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is anint, it is the same number of channels to repeat for all tasks. If it isNone, no repeat is applied. - to_tensor (
bool|dict[int, bool]): whether to include theToTensor()transform. Default isTrue. If it is a dict, the keys are task IDs and the values are whether to include theToTensor()transform for each task. If it is a single boolean value, it is applied to all tasks. - resize (
tuple[int, int]|Noneor dict of them): the size to resize the images to. Default isNone, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it isNone, no resize is applied.
The dict of custom transforms for each task.
The custom transforms for the current task self.task_id.
The training dataset object. Can be a PyTorch Dataset object or any other dataset object.
The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.
The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.
Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid values range from 1 to the number of tasks in the CL dataset.
```python
def sanity_check(self) -> None:
    r"""Sanity check."""

    # check if each task has been provided with necessary arguments
    for attr in [
        "root",
        "batch_size",
        "num_workers",
        "custom_transforms",
        "repeat_channels",
        "to_tensor",
        "resize",
    ]:
        value = getattr(self, attr)
        expected_keys = set(range(1, self.num_tasks + 1))
        if set(value.keys()) != expected_keys:
            raise ValueError(
                f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
            )
```
Sanity check.
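The check requires every per-task dict to have exactly the keys 1 through `num_tasks`. An illustrative standalone version of the key test (the function name `has_valid_task_keys` is hypothetical):

```python
# Illustrative version of the dict-key sanity check: every per-task dict must
# have keys that are exactly the consecutive integers 1..num_tasks.
def has_valid_task_keys(per_task_dict, num_tasks):
    return set(per_task_dict.keys()) == set(range(1, num_tasks + 1))

ok = has_valid_task_keys({1: 64, 2: 64, 3: 64}, num_tasks=3)   # valid keys 1..3
bad = has_valid_task_keys({0: 64, 1: 64}, num_tasks=2)         # invalid: keys start at 0
```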
```python
@abstractmethod
def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
    r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.

    **Args:**
    - **task_id** (`int`): the task ID to query the CL class map.

    **Returns:**
    - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
        - If `self.cl_paradigm` is 'TIL', the mapped class labels of each task should be continuous integers from 0 to the number of classes.
        - If `self.cl_paradigm` is 'CIL', the mapped class labels of each task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
    """
```
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.

**Args:**

- **task_id** (`int`): the task ID to query the CL class map.

**Returns:**

- **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
    - If `self.cl_paradigm` is 'TIL', the mapped class labels of each task should be continuous integers from 0 to the number of classes.
    - If `self.cl_paradigm` is 'CIL', the mapped class labels of each task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
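For the common case where every task has the same number of classes, the TIL/CIL contract can be sketched as follows (the function `make_cl_class_map` is illustrative only; subclasses implement `get_cl_class_map` themselves, and tasks with unequal class counts would need a cumulative offset instead):

```python
# Sketch of the TIL/CIL class-map contract for equal-sized tasks.
def make_cl_class_map(task_id, original_classes, cl_paradigm):
    if cl_paradigm == "TIL":
        # each task maps its classes to 0..num_classes-1 independently
        return {c: i for i, c in enumerate(original_classes)}
    if cl_paradigm == "CIL":
        # labels continue after all previous tasks' classes
        offset = (task_id - 1) * len(original_classes)
        return {c: i + offset for i, c in enumerate(original_classes)}
    raise ValueError(cl_paradigm)

digits = list(range(10))  # e.g. a 10-class MNIST-like task
til_map = make_cl_class_map(task_id=3, original_classes=digits, cl_paradigm="TIL")
cil_map = make_cl_class_map(task_id=3, original_classes=digits, cl_paradigm="CIL")
```

Under TIL, task 3 still uses labels 0-9; under CIL, task 3 uses labels 20-29.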
```python
@abstractmethod
def prepare_data(self) -> None:
    r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`. This method is called at the beginning of each task."""
```
Use this to download and prepare data. It must be implemented by subclasses, as required by LightningDataModule. This method is called at the beginning of each task.
```python
def setup(self, stage: str) -> None:
    r"""Set up the dataset for different stages. This method is called at the beginning of each task.

    **Args:**
    - **stage** (`str`): the stage of the experiment; one of:
        - 'fit': training and validation datasets of the current task `self.task_id` are assigned to `self.dataset_train_t` and `self.dataset_val_t`.
        - 'test': a dict of test datasets of all seen tasks should be assigned to `self.dataset_test`.
    """
    if stage == "fit":
        # these two stages must be done together because a sanity check for
        # validation is conducted before training
        pylogger.debug(
            "Construct train and validation dataset for task %d...", self.task_id
        )

        self.dataset_train_t, self.dataset_val_t = self.train_and_val_dataset()

        pylogger.info(
            "Train and validation dataset for task %d are ready.", self.task_id
        )
        pylogger.info(
            "Train dataset for task %d size: %d",
            self.task_id,
            len(self.dataset_train_t),
        )
        pylogger.info(
            "Validation dataset for task %d size: %d",
            self.task_id,
            len(self.dataset_val_t),
        )

    elif stage == "test":
        pylogger.debug("Construct test dataset for task %d...", self.task_id)

        self.dataset_test[self.task_id] = self.test_dataset()

        pylogger.info("Test dataset for task %d is ready.", self.task_id)
        pylogger.info(
            "Test dataset for task %d size: %d",
            self.task_id,
            len(self.dataset_test[self.task_id]),
        )
```
Set up the dataset for different stages. This method is called at the beginning of each task.

**Args:**

- **stage** (`str`): the stage of the experiment; one of:
    - 'fit': training and validation datasets of the current task `self.task_id` are assigned to `self.dataset_train_t` and `self.dataset_val_t`.
    - 'test': a dict of test datasets of all seen tasks should be assigned to `self.dataset_test`.
```python
def setup_task_id(self, task_id: int) -> None:
    r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` is called.

    **Args:**
    - **task_id** (`int`): the target task ID.
    """

    self.task_id = task_id

    self.root_t = self.root[task_id]
    self.batch_size_t = self.batch_size[task_id]
    self.num_workers_t = self.num_workers[task_id]
    self.custom_transforms_t = self.custom_transforms[task_id]
    self.repeat_channels_t = self.repeat_channels[task_id]
    self.to_tensor_t = self.to_tensor[task_id]
    self.resize_t = self.resize[task_id]

    self.processed_task_ids.append(task_id)
```
Set up which task's dataset the CL experiment is on. This must be done before `setup()` is called.

**Args:**

- **task_id** (`int`): the target task ID.
```python
def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
    r"""Set up tasks for continual learning main evaluation.

    **Args:**
    - **eval_tasks** (`list[int]`): the list of task IDs to evaluate.
    """
    for task_id in eval_tasks:
        self.setup_task_id(task_id=task_id)
        self.setup(stage="test")
```
Set up tasks for continual learning main evaluation.

**Args:**

- **eval_tasks** (`list[int]`): the list of task IDs to evaluate.
```python
def set_cl_paradigm(self, cl_paradigm: str) -> None:
    r"""Set `cl_paradigm` to `self.cl_paradigm`. It is used to define the CL class map.

    **Args:**
    - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
    """
    self.cl_paradigm = cl_paradigm
```
Set `cl_paradigm` to `self.cl_paradigm`. It is used to define the CL class map.

**Args:**

- **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
```python
def train_and_val_transforms(self) -> transforms.Compose:
    r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.

    **Returns:**
    - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
    """
    repeat_channels_transform = (
        transforms.Grayscale(num_output_channels=self.repeat_channels_t)
        if self.repeat_channels_t is not None
        else None
    )
    to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
    resize_transform = (
        transforms.Resize(self.resize_t) if self.resize_t is not None else None
    )
    normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

    return transforms.Compose(
        list(
            filter(
                None,
                [
                    repeat_channels_transform,
                    to_tensor_transform,
                    resize_transform,
                    self.custom_transforms_t,
                    normalization_transform,
                ],
            )
        )
    )  # the order of transforms matters
```
Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.

**Returns:**

- **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
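The `filter(None, [...])` pattern in the method drops the optional transforms that were left as `None`, while preserving the order of the rest. A plain-callable sketch of the same idea (no torchvision involved; the lambdas just stand in for transforms):

```python
# Optional transforms that are switched off are represented as None; filtering
# them out keeps only the active ones, in their original order, just like
# transforms.Compose(list(filter(None, [...]))).
candidates = [None, lambda x: x * 2, None, lambda x: x + 1]  # e.g. repeat/resize off
pipeline = list(filter(None, candidates))

x = 10
for transform in pipeline:  # apply in order, as transforms.Compose would
    x = transform(x)
```

Order matters here just as in the real pipeline: swapping the two active callables would give a different result.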
```python
def test_transforms(self) -> transforms.Compose:
    r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. It is used in subclasses when constructing the dataset.

    **Returns:**
    - **test_transforms** (`transforms.Compose`): the composed test transforms.
    """
    repeat_channels_transform = (
        transforms.Grayscale(num_output_channels=self.repeat_channels_t)
        if self.repeat_channels_t is not None
        else None
    )
    to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
    resize_transform = (
        transforms.Resize(self.resize_t) if self.resize_t is not None else None
    )
    normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

    return transforms.Compose(
        list(
            filter(
                None,
                [
                    repeat_channels_transform,
                    to_tensor_transform,
                    resize_transform,
                    normalization_transform,
                ],
            )
        )
    )  # the order of transforms matters. No custom transforms for test
```
Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. It is used in subclasses when constructing the dataset.

**Returns:**

- **test_transforms** (`transforms.Compose`): the composed test transforms.
```python
def target_transform(self) -> ClassMapping:
    r"""Target transform to map the original class labels to CL class labels according to `self.cl_paradigm`. It can be used in subclasses when constructing the dataset.

    **Returns:**
    - **target_transform** (`Callable`): the target transform function.
    """

    cl_class_map = self.get_cl_class_map(task_id=self.task_id)

    target_transform = ClassMapping(class_map=cl_class_map)

    return target_transform
```
Target transform to map the original class labels to CL class labels according to `self.cl_paradigm`. It can be used in subclasses when constructing the dataset.

**Returns:**

- **target_transform** (`Callable`): the target transform function.
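Conceptually, the `ClassMapping` callable just looks each original label up in the CL class map. A minimal sketch of that behavior (`ClassMappingSketch` is illustrative; the real `ClassMapping` is defined elsewhere in CLArena and may differ):

```python
# Minimal sketch of a target transform that maps original labels to CL labels
# via a class-map dict; a stand-in for CLArena's ClassMapping, not the real one.
class ClassMappingSketch:
    def __init__(self, class_map):
        self.class_map = class_map

    def __call__(self, target):
        return self.class_map[target]

to_cl_label = ClassMappingSketch({"cat": 0, "dog": 1})
mapped = to_cl_label("dog")
```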
```python
@abstractmethod
def train_and_val_dataset(self) -> tuple[Any, Any]:
    r"""Get the training and validation datasets of the current task `self.task_id`. It must be implemented by subclasses.

    **Returns:**
    - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation datasets of the current task `self.task_id`.
    """
```
Get the training and validation datasets of the current task `self.task_id`. It must be implemented by subclasses.

**Returns:**

- **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation datasets of the current task `self.task_id`.
```python
@abstractmethod
def test_dataset(self) -> Any:
    r"""Get the test dataset of the current task `self.task_id`. It must be implemented by subclasses.

    **Returns:**
    - **test_dataset** (`Any`): the test dataset of the current task `self.task_id`.
    """
```
Get the test dataset of the current task `self.task_id`. It must be implemented by subclasses.

**Returns:**

- **test_dataset** (`Any`): the test dataset of the current task `self.task_id`.
```python
def train_dataloader(self) -> DataLoader:
    r"""DataLoader generator for the train stage of the current task `self.task_id`. It is automatically called before training the task.

    **Returns:**
    - **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
    """

    pylogger.debug("Construct train dataloader for task %d...", self.task_id)

    return DataLoader(
        dataset=self.dataset_train_t,
        batch_size=self.batch_size_t,
        shuffle=True,  # shuffle train batch to prevent overfitting
        num_workers=self.num_workers_t,
        drop_last=True,  # to avoid batchnorm error (when batch_size is 1)
    )
```
DataLoader generator for the train stage of the current task `self.task_id`. It is automatically called before training the task.

**Returns:**

- **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
```python
def val_dataloader(self) -> DataLoader:
    r"""DataLoader generator for the validation stage of the current task `self.task_id`. It is automatically called before the task's validation.

    **Returns:**
    - **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
    """

    pylogger.debug("Construct validation dataloader for task %d...", self.task_id)

    return DataLoader(
        dataset=self.dataset_val_t,
        batch_size=self.batch_size_t,
        shuffle=False,  # don't have to shuffle val or test batch
        num_workers=self.num_workers_t,
    )
```
DataLoader generator for the validation stage of the current task `self.task_id`. It is automatically called before the task's validation.

**Returns:**

- **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
```python
def test_dataloader(self) -> dict[int, DataLoader]:
    r"""DataLoader generator for the test stage of the current task `self.task_id`. It is automatically called before testing the task.

    **Returns:**
    - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
    """

    pylogger.debug("Construct test dataloader for task %d...", self.task_id)

    return {
        task_id: DataLoader(
            dataset=dataset_test_t,
            batch_size=self.batch_size_t,
            shuffle=False,  # don't have to shuffle val or test batch
            num_workers=self.num_workers_t,
        )
        for task_id, dataset_test_t in self.dataset_test.items()
    }
```
DataLoader generator for the test stage of the current task `self.task_id`. It is automatically called before testing the task.

**Returns:**

- **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
```python
class CLPermutedDataset(CLDataset):
    r"""The base class of continual learning datasets constructed as permutations of an original dataset."""

    original_dataset_python_class: type[Dataset]
    r"""The original dataset class. **It must be provided in subclasses.**"""

    def __init__(
        self,
        root: str,
        num_tasks: int,
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
        permutation_mode: str = "first_channel_only",
        permutation_seeds: dict[int, int] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original dataset lives.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs, from 1 to `num_tasks`.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
            If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
            If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
            If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
            If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
            If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize.
            If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        - **permutation_mode** (`str`): the mode of permutation; one of:
            1. 'all': permute all pixels.
            2. 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
            3. 'first_channel_only': permute only the first channel.
        - **permutation_seeds** (`dict[int, int]` | `None`): the dict of seeds for the permutation operations used to construct each task. Keys are task IDs and values are permutation seeds for each task. Default is `None`, which creates a dict of seeds from 0 to `num_tasks` - 1.
        """
        super().__init__(
            root=root,
            num_tasks=num_tasks,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.original_dataset_constants: type[DatasetConstants] = (
            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
        )
        r"""The original dataset constants class."""

        self.permutation_mode: str = permutation_mode
        r"""The mode of permutation."""
        self.permutation_seeds: dict[int, int] = (
            permutation_seeds
            if permutation_seeds
            else {t: t - 1 for t in range(1, num_tasks + 1)}
        )
        r"""The dict of permutation seeds for each task."""

        self.permutation_seed_t: int
        r"""The permutation seed for the current task `self.task_id`."""
        self.permute_transform_t: Permute
        r"""The permutation transform for the current task `self.task_id`."""

        CLPermutedDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check the permutation mode
        if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]:
            raise ValueError(
                "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'."
            )

        # check the permutation seeds
        expected_keys = set(range(1, self.num_tasks + 1))
        if set(self.permutation_seeds.keys()) != expected_keys:
            raise ValueError(
                "The permutation_seeds dict keys must be consecutive integers from 1 to num_tasks."
            )

    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
            - If `self.cl_paradigm` is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """

        num_classes_t = (
            self.original_dataset_constants.NUM_CLASSES
        )  # the same as the original dataset
        class_map_t = (
            self.original_dataset_constants.CLASS_MAP
        )  # the same as the original dataset

        if self.cl_paradigm in ("TIL", "DIL"):  # note: `== "TIL" or "DIL"` would always be truthy
            return {class_map_t[i]: i for i in range(num_classes_t)}
        if self.cl_paradigm == "CIL":
            return {
                class_map_t[i]: i + (task_id - 1) * num_classes_t
                for i in range(num_classes_t)
            }

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        CLDataset.setup_task_id(self, task_id)

        self.mean_t = (
            self.original_dataset_constants.MEAN
        )  # the same as the original dataset
        self.std_t = (
            self.original_dataset_constants.STD
        )  # the same as the original dataset

        num_channels = (
            self.original_dataset_constants.NUM_CHANNELS
            if self.repeat_channels_t is None
            else self.repeat_channels_t
        )

        if (
            hasattr(self.original_dataset_constants, "IMG_SIZE")
            or self.resize_t is not None
        ):
            img_size = (
                self.original_dataset_constants.IMG_SIZE
                if self.resize_t is None
                else torch.Size(self.resize_t)
            )
        else:
            raise AttributeError(
                "The original dataset has varying image sizes. Please resize the images to a fixed size by specifying the hyperparameter: resize."
            )

        # set up the permutation transform
        self.permutation_seed_t = self.permutation_seeds[task_id]
        self.permute_transform_t = Permute(
            num_channels=num_channels,
            img_size=img_size,
            mode=self.permutation_mode,
            seed=self.permutation_seed_t,
        )

    def train_and_val_transforms(self) -> transforms.Compose:
        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform also applies.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """

        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.permute_transform_t,  # permutation is included here
                        self.custom_transforms_t,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

    def test_transforms(self) -> transforms.Compose:
        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform also applies.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """

        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.permute_transform_t,  # permutation is included here
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters. No custom transforms for test
```
The base class of continual learning datasets constructed as permutations of an original dataset.
**Args:**

- **root** (`str`): the root directory where the original dataset lives.
- **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs, from 1 to `num_tasks`.
- **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
- **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
- **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
- **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks.
- **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
- **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks.
- **permutation_mode** (`str`): the mode of permutation; one of:
    - 'all': permute all pixels.
    - 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
    - 'first_channel_only': permute only the first channel.
- **permutation_seeds** (`dict[int, int]` | `None`): the dict of seeds for the permutation operations used to construct each task. Keys are task IDs and values are permutation seeds for each task. Default is `None`, which creates a dict of seeds from 0 to `num_tasks` - 1.
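When `permutation_seeds` is left as `None`, the constructor builds the default seed dict, assigning task `t` the seed `t - 1`:

```python
# Default permutation seeds, as built in __init__ when permutation_seeds is None:
# task IDs run from 1 to num_tasks, seeds from 0 to num_tasks - 1.
num_tasks = 5
default_seeds = {t: t - 1 for t in range(1, num_tasks + 1)}
```

Distinct seeds per task are what make each task's pixel permutation different.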
The original dataset class. It must be provided in subclasses.
The original dataset constants class.
The permutation transform for the current task self.task_id.
Sanity check.
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

**Args:**
- **task_id** (`int`): the task ID to query the CL class map.

**Returns:**
- **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
    - If `self.cl_paradigm` is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
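The class-map arithmetic can be illustrated with a standalone sketch. The toy values (10 original classes, an identity `CLASS_MAP`) are assumptions for illustration; the function mirrors the logic of `get_cl_class_map` for permuted datasets but is not the CLArena API itself:

```python
# Toy setup: every permuted task reuses all 10 original classes.
NUM_CLASSES = 10
CLASS_MAP = {i: i for i in range(NUM_CLASSES)}  # original index -> original label

def cl_class_map(task_id: int, cl_paradigm: str) -> dict[int, int]:
    if cl_paradigm in ("TIL", "DIL"):
        # every task maps its classes to 0..NUM_CLASSES-1
        return {CLASS_MAP[i]: i for i in range(NUM_CLASSES)}
    if cl_paradigm == "CIL":
        # task t's labels are offset by the classes of the (t-1) earlier tasks
        return {
            CLASS_MAP[i]: i + (task_id - 1) * NUM_CLASSES
            for i in range(NUM_CLASSES)
        }
    raise ValueError(f"Unsupported cl_paradigm: {cl_paradigm}")

# Under CIL, task 3's labels start at 20; under TIL they start at 0.
```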
```python
    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        CLDataset.setup_task_id(self, task_id)

        self.mean_t = (
            self.original_dataset_constants.MEAN
        )  # the same as the original dataset
        self.std_t = (
            self.original_dataset_constants.STD
        )  # the same as the original dataset

        num_channels = (
            self.original_dataset_constants.NUM_CHANNELS
            if self.repeat_channels_t is None
            else self.repeat_channels_t
        )

        if (
            hasattr(self.original_dataset_constants, "IMG_SIZE")
            or self.resize_t is not None
        ):
            img_size = (
                self.original_dataset_constants.IMG_SIZE
                if self.resize_t is None
                else torch.Size(self.resize_t)
            )
        else:
            raise AttributeError(
                "The original dataset has different image sizes. Please resize the images to a fixed size by specifying the hyperparameter: resize."
            )

        # set up the permutation transform
        self.permutation_seed_t = self.permutation_seeds[task_id]
        self.permute_transform_t = Permute(
            num_channels=num_channels,
            img_size=img_size,
            mode=self.permutation_mode,
            seed=self.permutation_seed_t,
        )
```
Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

**Args:**
- **task_id** (`int`): the target task ID.
```python
    def train_and_val_transforms(self) -> transforms.Compose:
        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform also applies.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """

        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.permute_transform_t,  # permutation is included here
                        self.custom_transforms_t,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters
```
Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform also applies.

**Returns:**
- **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
```python
    def test_transforms(self) -> transforms.Compose:
        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform also applies.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """

        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.permute_transform_t,  # permutation is included here
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters. No custom transforms for test
```
Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform also applies.

**Returns:**
- **test_transforms** (`transforms.Compose`): the composed test transforms.
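The transform builders all rely on the same trick: optional steps are `None`, and `filter(None, ...)` drops them before composition, so only the configured transforms remain in the pipeline. A minimal stand-in sketch; this `Compose` is a toy substitute for `torchvision.transforms.Compose`, and the lambda transforms are placeholders:

```python
class Compose:
    """Toy stand-in for torchvision.transforms.Compose: applies steps in order."""
    def __init__(self, steps):
        self.steps = steps
    def __call__(self, x):
        for step in self.steps:
            x = step(x)
        return x

to_tensor = None          # an optional step disabled, e.g. via to_tensor=False
resize = lambda x: x * 2  # placeholder transform
normalize = lambda x: x - 1  # placeholder transform

# filter(None, ...) drops the disabled (None) step before composing.
pipeline = Compose(list(filter(None, [to_tensor, resize, normalize])))
assert len(pipeline.steps) == 2
assert pipeline(3) == 5  # (3 * 2) - 1
```

The order of the remaining steps is preserved, which is why the source comments stress that the order of transforms matters.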
```python
class CLRotatedDataset(CLDataset):
    r"""The base class of continual learning datasets constructed as rotations of an original dataset."""

    original_dataset_python_class: type[Dataset]
    r"""The original dataset class. **It must be provided in subclasses.** """

    def __init__(
        self,
        root: str,
        num_tasks: int,
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
        rotation_degrees: dict[int, float] | list[float] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original dataset lives.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and rotation are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        - **rotation_degrees** (`dict[int, float]` | `list[float]` | `None`): the rotation degrees for each task. If it is a list, its length must match `num_tasks` and it is mapped to task IDs in order. If it is `None`, angles are evenly spaced in [0, 180) across tasks.
        """
        super().__init__(
            root=root,
            num_tasks=num_tasks,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.original_dataset_constants: type[DatasetConstants] = (
            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
        )
        r"""The original dataset constants class."""

        if isinstance(rotation_degrees, (DictConfig, ListConfig)):
            rotation_degrees = OmegaConf.to_container(rotation_degrees)

        if rotation_degrees is None:
            step = 180.0 / num_tasks
            rotation_degrees = {t: (t - 1) * step for t in range(1, num_tasks + 1)}
        elif isinstance(rotation_degrees, (list, tuple)):
            if len(rotation_degrees) != num_tasks:
                raise ValueError("rotation_degrees length must match num_tasks.")
            rotation_degrees = {
                t: float(rotation_degrees[t - 1]) for t in range(1, num_tasks + 1)
            }
        elif isinstance(rotation_degrees, dict):
            rotation_degrees = {
                int(task_id): float(angle)
                for task_id, angle in rotation_degrees.items()
            }
        else:
            raise TypeError("rotation_degrees must be a dict, list, or None.")

        self.rotation_degrees: dict[int, float] = rotation_degrees
        r"""The rotation degrees for each task."""
        self.rotation_degree_t: float
        r"""The rotation degree for the current task `self.task_id`."""
        self.rotation_transform_t: transforms.RandomRotation
        r"""The rotation transform for the current task `self.task_id`."""

        CLRotatedDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        expected_keys = set(range(1, self.num_tasks + 1))
        if set(self.rotation_degrees.keys()) != expected_keys:
            raise ValueError(
                "rotation_degrees dict keys must be consecutive integers from 1 to num_tasks."
            )

    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """
        num_classes_t = self.original_dataset_constants.NUM_CLASSES
        class_map_t = self.original_dataset_constants.CLASS_MAP

        if self.cl_paradigm in ["TIL", "DIL"]:
            return {class_map_t[i]: i for i in range(num_classes_t)}
        if self.cl_paradigm == "CIL":
            return {
                class_map_t[i]: i + (task_id - 1) * num_classes_t
                for i in range(num_classes_t)
            }

        raise ValueError(f"Unsupported cl_paradigm: {self.cl_paradigm}")

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        CLDataset.setup_task_id(self, task_id)

        self.mean_t = self.original_dataset_constants.MEAN
        self.std_t = self.original_dataset_constants.STD

        self.rotation_degree_t = self.rotation_degrees[task_id]
        self.rotation_transform_t = transforms.RandomRotation(
            degrees=(self.rotation_degree_t, self.rotation_degree_t),
            fill=0,
        )

    def train_and_val_transforms(self) -> transforms.Compose:
        r"""Transforms for training and validation datasets. Rotation is applied before `ToTensor()` to keep PIL-based rotation.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        self.rotation_transform_t,
                        to_tensor_transform,
                        resize_transform,
                        self.custom_transforms_t,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

    def test_transforms(self) -> transforms.Compose:
        r"""Transforms for the test dataset. Rotation is applied before `ToTensor()` to keep PIL-based rotation.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        self.rotation_transform_t,
                        to_tensor_transform,
                        resize_transform,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters. No custom transforms for test
```
The base class of continual learning datasets constructed as rotations of an original dataset.
**Args:**
- **root** (`str`): the root directory where the original dataset lives.
- **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
- **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
- **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
- **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and rotation are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
- **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
- **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
- **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
- **rotation_degrees** (`dict[int, float]` | `list[float]` | `None`): the rotation degrees for each task. If it is a list, its length must match `num_tasks` and it is mapped to task IDs in order. If it is `None`, angles are evenly spaced in [0, 180) across tasks.
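The normalization of `rotation_degrees` into a per-task dict can be sketched as a standalone function (the name `normalize_rotation_degrees` is hypothetical; it mirrors the constructor logic above but is not part of the CLArena API):

```python
def normalize_rotation_degrees(rotation_degrees, num_tasks):
    # None -> angles evenly spaced in [0, 180) across tasks
    if rotation_degrees is None:
        step = 180.0 / num_tasks
        return {t: (t - 1) * step for t in range(1, num_tasks + 1)}
    # list/tuple -> positional mapping onto task IDs 1..num_tasks
    if isinstance(rotation_degrees, (list, tuple)):
        if len(rotation_degrees) != num_tasks:
            raise ValueError("rotation_degrees length must match num_tasks.")
        return {t: float(rotation_degrees[t - 1]) for t in range(1, num_tasks + 1)}
    # dict -> coerce keys to int and angles to float
    if isinstance(rotation_degrees, dict):
        return {int(t): float(a) for t, a in rotation_degrees.items()}
    raise TypeError("rotation_degrees must be a dict, list, or None.")

# With 4 tasks and no angles given: 0, 45, 90, and 135 degrees.
```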
The original dataset class. It must be provided in subclasses.
The original dataset constants class.
The rotation transform for the current task self.task_id.
Sanity check.
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

**Args:**
- **task_id** (`int`): the task ID to query the CL class map.

**Returns:**
- **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
    - If `self.cl_paradigm` is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

**Args:**
- **task_id** (`int`): the target task ID.
Transforms for training and validation datasets. Rotation is applied before `ToTensor()` to keep PIL-based rotation.

**Returns:**
- **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
Transforms for the test dataset. Rotation is applied before `ToTensor()` to keep PIL-based rotation.

**Returns:**
- **test_transforms** (`transforms.Compose`): the composed test transforms.
```python
class CLSplitDataset(CLDataset):
    r"""The base class of continual learning datasets constructed as splits of an original dataset."""

    original_dataset_python_class: type[Dataset]
    r"""The original dataset class. **It must be provided in subclasses.** """

    def __init__(
        self,
        root: str,
        class_split: dict[int, list[int]],
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original dataset lives.
        - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs and the values are lists of class labels (integers starting from 0) to split for each task.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        """
        super().__init__(
            root=root,
            num_tasks=len(
                class_split
            ),  # num_tasks is not explicitly provided, but derived from the class_split length
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.original_dataset_constants: type[DatasetConstants] = (
            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
        )
        r"""The original dataset constants class."""

        self.class_split: dict[int, list[int]] = OmegaConf.to_container(class_split)
        r"""The dict of class splits for each task."""

        CLSplitDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check the class split
        expected_keys = set(range(1, self.num_tasks + 1))
        if set(self.class_split.keys()) != expected_keys:
            raise ValueError(
                f"{self.class_split} dict keys must be consecutive integers from 1 to num_tasks."
            )
        if any(len(split) < 2 for split in self.class_split.values()):
            raise ValueError("Each class split must contain at least 2 elements!")

    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """
        num_classes_t = len(
            self.class_split[task_id]
        )  # the number of classes in the current task, i.e. the length of the class split
        class_map_t = (
            self.original_dataset_constants.CLASS_MAP
        )  # the same as the original dataset

        if self.cl_paradigm == "TIL":
            return {
                class_map_t[self.class_split[task_id][i]]: i
                for i in range(num_classes_t)
            }
        if self.cl_paradigm == "CIL":
            num_classes_previous = sum(
                len(self.class_split[i]) for i in range(1, task_id)
            )
            return {
                class_map_t[self.class_split[task_id][i]]: num_classes_previous + i
                for i in range(num_classes_t)
            }

        raise ValueError(f"Unsupported cl_paradigm: {self.cl_paradigm}")

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """
        super().setup_task_id(task_id)

        self.mean_t = (
            self.original_dataset_constants.MEAN
        )  # the same as the original dataset
        self.std_t = (
            self.original_dataset_constants.STD
        )  # the same as the original dataset

    @abstractmethod
    def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
        r"""Get a subset of classes from the dataset for the current task `self.task_id`. It is used when constructing the split. **It must be implemented by subclasses.**

        **Args:**
        - **dataset** (`Dataset`): the dataset to retrieve the subset from.

        **Returns:**
        - **subset** (`Dataset`): the subset of classes from the dataset.
        """
```
The base class of continual learning datasets constructed as splits of an original dataset.
```python
def __init__(
    self,
    root: str,
    class_split: dict[int, list[int]],
    batch_size: int | dict[int, int] = 1,
    num_workers: int | dict[int, int] = 0,
    custom_transforms: (
        Callable
        | transforms.Compose
        | None
        | dict[int, Callable | transforms.Compose | None]
    ) = None,
    repeat_channels: int | None | dict[int, int | None] = None,
    to_tensor: bool | dict[int, bool] = True,
    resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
) -> None:
    r"""
    **Args:**
    - **root** (`str`): the root directory where the original dataset lives.
    - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs and the values are lists of class labels (integers starting from 0) to split for each task.
    - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, the same batch size is used for all tasks.
    - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
        If it is a dict, the keys are task IDs and the values are the numbers of workers for each task. If it is an `int`, the same number of workers is used for all tasks.
    - **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
    - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the numbers of channels to repeat for each task. If it is an `int`, the same number of channels is repeated for all tasks. If it is `None`, no repeat is applied.
    - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
    - **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize.
        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
    """
    super().__init__(
        root=root,
        num_tasks=len(
            class_split
        ),  # num_tasks is not explicitly provided but derived from the length of class_split
        batch_size=batch_size,
        num_workers=num_workers,
        custom_transforms=custom_transforms,
        repeat_channels=repeat_channels,
        to_tensor=to_tensor,
        resize=resize,
    )

    self.original_dataset_constants: type[DatasetConstants] = (
        DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
    )
    r"""The original dataset constants class."""

    self.class_split: dict[int, list[int]] = OmegaConf.to_container(class_split)
    r"""The dict of class splits for each task."""

    CLSplitDataset.sanity_check(self)
```
**Args:**

- **root** (`str`): the root directory where the original dataset lives.
- **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs and the values are lists of class labels (integers starting from 0) to split for each task.
- **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, the same batch size is used for all tasks.
- **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the numbers of workers for each task. If it is an `int`, the same number of workers is used for all tasks.
- **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
- **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the numbers of channels to repeat for each task. If it is an `int`, the same number of channels is repeated for all tasks. If it is `None`, no repeat is applied.
- **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
- **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
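Most constructor arguments above follow the same scalar-or-dict convention: pass one value shared by all tasks, or a dict keyed by task ID. A minimal sketch of that convention, assuming a hypothetical helper `resolve_for_task` (illustrative only, not part of the CLArena API):

```python
def resolve_for_task(option, task_id):
    """Resolve a per-task option that may be a single value or a dict keyed by task ID."""
    if isinstance(option, dict):
        return option[task_id]
    return option  # scalar form: shared by all tasks

# A hypothetical 3-task split of a 10-class dataset.
class_split = {1: [0, 1, 2], 2: [3, 4, 5], 3: [6, 7, 8, 9]}
num_tasks = len(class_split)  # num_tasks is derived from class_split, not passed

batch_size = {1: 64, 2: 128, 3: 128}  # dict form: per-task batch sizes
num_workers = 4                       # scalar form: same for every task

print(resolve_for_task(batch_size, 2))   # per-task value
print(resolve_for_task(num_workers, 2))  # shared value
```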
The original dataset class. It must be provided in subclasses.
The original dataset constants class.
```python
def sanity_check(self) -> None:
    r"""Sanity check."""

    # check the class split
    expected_keys = set(range(1, self.num_tasks + 1))
    if set(self.class_split.keys()) != expected_keys:
        raise ValueError(
            f"{self.class_split} dict keys must be consecutive integers from 1 to num_tasks."
        )
    if any(len(split) < 2 for split in self.class_split.values()):
        raise ValueError("Each class split must contain at least 2 elements!")
```
Sanity check.
```python
def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
    r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

    **Args:**
    - **task_id** (`int`): the task ID to query the CL class map.

    **Returns:**
    - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
        - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task are continuous integers from 0 to the number of classes.
        - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task are continuous integers starting from the total number of classes of previous tasks.
    """
    num_classes_t = len(
        self.class_split[task_id]
    )  # the number of classes in the current task, i.e. the length of the class split
    class_map_t = (
        self.original_dataset_constants.CLASS_MAP
    )  # the same as the original dataset

    if self.cl_paradigm == "TIL":
        return {
            class_map_t[self.class_split[task_id][i]]: i
            for i in range(num_classes_t)
        }
    if self.cl_paradigm == "CIL":
        num_classes_previous = sum(
            len(self.class_split[i]) for i in range(1, task_id)
        )
        return {
            class_map_t[self.class_split[task_id][i]]: num_classes_previous + i
            for i in range(num_classes_t)
        }
```
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

**Args:**

- **task_id** (`int`): the task ID to query the CL class map.

**Returns:**

- **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
    - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task are continuous integers from 0 to the number of classes.
    - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task are continuous integers starting from the total number of classes of previous tasks.
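The TIL/CIL distinction can be made concrete. The sketch below mirrors the mapping logic for a hypothetical three-task split, assuming an identity `CLASS_MAP`; it is an illustration, not the library code:

```python
# Hypothetical split of a 6-class dataset into 3 tasks of 2 classes each.
class_split = {1: [0, 1], 2: [2, 3], 3: [4, 5]}
class_map = {i: i for i in range(6)}  # identity CLASS_MAP stand-in

def get_cl_class_map(task_id, cl_paradigm):
    classes_t = class_split[task_id]
    if cl_paradigm == "TIL":
        # each task relabels its own classes from 0
        return {class_map[c]: i for i, c in enumerate(classes_t)}
    if cl_paradigm == "CIL":
        # labels continue from the total class count of previous tasks
        offset = sum(len(class_split[t]) for t in range(1, task_id))
        return {class_map[c]: offset + i for i, c in enumerate(classes_t)}

print(get_cl_class_map(2, "TIL"))  # → {2: 0, 3: 1}
print(get_cl_class_map(2, "CIL"))  # → {2: 2, 3: 3}
```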
```python
def setup_task_id(self, task_id: int) -> None:
    r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

    **Args:**
    - **task_id** (`int`): the target task ID.
    """
    super().setup_task_id(task_id)

    self.mean_t = (
        self.original_dataset_constants.MEAN
    )  # the same as the original dataset
    self.std_t = (
        self.original_dataset_constants.STD
    )  # the same as the original dataset
```
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

**Args:**

- **task_id** (`int`): the target task ID.
```python
@abstractmethod
def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
    r"""Get a subset of classes from the dataset for the current task `self.task_id`. It is used when constructing the split. **It must be implemented by subclasses.**

    **Args:**
    - **dataset** (`Dataset`): the dataset to retrieve the subset from.

    **Returns:**
    - **subset** (`Dataset`): the subset of classes from the dataset.
    """
```
Get a subset of classes from the dataset for the current task `self.task_id`. It is used when constructing the split. **It must be implemented by subclasses.**

**Args:**

- **dataset** (`Dataset`): the dataset to retrieve the subset from.

**Returns:**

- **subset** (`Dataset`): the subset of classes from the dataset.
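A subclass would typically implement `get_subset_of_classes` by filtering sample indices on the dataset's labels. The sketch below uses a plain-Python stand-in for a torchvision-style dataset (one with a `targets` attribute) instead of `torch.utils.data.Subset`, so it runs on its own; all names here are illustrative, not CLArena's:

```python
class ToyDataset:
    """Stand-in for a torchvision-style dataset exposing `data` and `targets`."""
    def __init__(self, data, targets):
        self.data, self.targets = data, targets
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return self.data[i], self.targets[i]

def get_subset_of_classes(dataset, classes):
    """Keep only the samples whose label belongs to the current task's classes."""
    idx = [i for i, y in enumerate(dataset.targets) if y in classes]
    return ToyDataset(
        [dataset.data[i] for i in idx],
        [dataset.targets[i] for i in idx],
    )

ds = ToyDataset(data=list("abcdef"), targets=[0, 1, 2, 0, 1, 2])
sub = get_subset_of_classes(ds, classes=[0, 2])
print(len(sub), sub.targets)  # → 4 [0, 2, 0, 2]
```

A real implementation would wrap the selected indices in `torch.utils.data.Subset` rather than copying the data.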
```python
class CLCombinedDataset(CLDataset):
    r"""The base class of continual learning datasets constructed as combinations of several single-task datasets (one dataset per task)."""
```
The base class of continual learning datasets constructed as combinations of several single-task datasets (one dataset per task).
```python
def __init__(
    self,
    datasets: dict[int, str],
    root: str | dict[int, str],
    batch_size: int | dict[int, int] = 1,
    num_workers: int | dict[int, int] = 0,
    custom_transforms: (
        Callable
        | transforms.Compose
        | None
        | dict[int, Callable | transforms.Compose | None]
    ) = None,
    repeat_channels: int | None | dict[int, int | None] = None,
    to_tensor: bool | dict[int, bool] = True,
    resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
) -> None:
    r"""
    **Args:**
    - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
    - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, the same root directory is used for all tasks.
    - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, the same batch size is used for all tasks.
    - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
        If it is a dict, the keys are task IDs and the values are the numbers of workers for each task. If it is an `int`, the same number of workers is used for all tasks.
    - **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
    - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the numbers of channels to repeat for each task. If it is an `int`, the same number of channels is repeated for all tasks. If it is `None`, no repeat is applied.
    - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
    - **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize.
        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
    """
    super().__init__(
        root=root,
        num_tasks=len(
            datasets
        ),  # num_tasks is not explicitly provided but derived from the length of datasets
        batch_size=batch_size,
        num_workers=num_workers,
        custom_transforms=custom_transforms,
        repeat_channels=repeat_channels,
        to_tensor=to_tensor,
        resize=resize,
    )

    self.original_dataset_python_classes: dict[int, Dataset] = {
        t: str_to_class(dataset_class_path)
        for t, dataset_class_path in datasets.items()
    }
    r"""The dict of dataset classes for each task."""
    self.original_dataset_python_class_t: Dataset
    r"""The dataset class for the current task `self.task_id`."""
    self.original_dataset_constants_t: type[DatasetConstants]
    r"""The original dataset constants class for the current task `self.task_id`."""
```
**Args:**

- **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
- **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, the same root directory is used for all tasks.
- **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, the same batch size is used for all tasks.
- **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the numbers of workers for each task. If it is an `int`, the same number of workers is used for all tasks.
- **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
- **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the numbers of channels to repeat for each task. If it is an `int`, the same number of channels is repeated for all tasks. If it is `None`, no repeat is applied.
- **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
- **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
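A hypothetical configuration for a combined benchmark might look like the following; the dataset class paths and root directories are illustrative only, and `str_to_class` is sketched here with `importlib` rather than taken from CLArena:

```python
import importlib

def str_to_class(path: str):
    """Resolve a 'module.ClassName' string to the class object."""
    module_name, cls_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), cls_name)

# One single-task dataset per task; paths and roots are examples only.
datasets = {
    1: "torchvision.datasets.MNIST",
    2: "torchvision.datasets.FashionMNIST",
    3: "torchvision.datasets.KMNIST",
}
root = {1: "data/mnist", 2: "data/fashionmnist", 3: "data/kmnist"}
num_tasks = len(datasets)  # derived from the datasets dict, not passed explicitly

# str_to_class works for any importable path, e.g. from the standard library:
print(str_to_class("collections.Counter").__name__)  # → Counter
```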
The dict of dataset classes for each task.
The dataset class for the current task `self.task_id`.
The original dataset constants class for the current task `self.task_id`.
```python
def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
    r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

    **Args:**
    - **task_id** (`int`): the task ID to query the CL class map.

    **Returns:**
    - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
        - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task are continuous integers from 0 to the number of classes.
        - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task are continuous integers starting from the total number of classes of previous tasks.
    """
    original_dataset_python_class_t = self.original_dataset_python_classes[task_id]
    original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
        original_dataset_python_class_t
    ]
    num_classes_t = original_dataset_constants_t.NUM_CLASSES
    class_map_t = original_dataset_constants_t.CLASS_MAP

    if self.cl_paradigm == "TIL":
        return {class_map_t[i]: i for i in range(num_classes_t)}
    if self.cl_paradigm == "CIL":
        num_classes_previous = sum(
            DATASET_CONSTANTS_MAPPING[
                self.original_dataset_python_classes[i]
            ].NUM_CLASSES
            for i in range(1, task_id)
        )
        return {
            class_map_t[i]: num_classes_previous + i for i in range(num_classes_t)
        }
```
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

**Args:**

- **task_id** (`int`): the task ID to query the CL class map.

**Returns:**

- **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
    - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task are continuous integers from 0 to the number of classes.
    - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task are continuous integers starting from the total number of classes of previous tasks.
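In the CIL branch, each task's labels are offset by the total class count of all previous tasks' datasets. A small numeric sketch of that offset arithmetic, with made-up per-task class counts (e.g. two 10-class digit datasets followed by a 26-class letters dataset):

```python
# Hypothetical per-task class counts for a combined benchmark.
num_classes = {1: 10, 2: 10, 3: 26}

def cil_offset(task_id: int) -> int:
    """Total number of classes of all previous tasks (the CIL label offset)."""
    return sum(num_classes[t] for t in range(1, task_id))

def cil_class_map(task_id: int) -> dict[int, int]:
    # identity CLASS_MAP assumed for brevity
    offset = cil_offset(task_id)
    return {c: offset + c for c in range(num_classes[task_id])}

# task 1 keeps labels 0-9, task 2 gets 10-19, task 3 gets 20-45
print(cil_offset(3))          # → 20
print(cil_class_map(2)[0])    # → 10
```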
```python
def setup_task_id(self, task_id: int) -> None:
    r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

    **Args:**
    - **task_id** (`int`): the target task ID.
    """
    self.original_dataset_python_class_t = self.original_dataset_python_classes[
        task_id
    ]

    self.original_dataset_constants_t: type[DatasetConstants] = (
        DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class_t]
    )

    super().setup_task_id(task_id)

    self.mean_t = self.original_dataset_constants_t.MEAN
    self.std_t = self.original_dataset_constants_t.STD
```
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

**Args:**

- **task_id** (`int`): the target task ID.