clarena.cl_datasets
Continual Learning Datasets
This submodule provides the continual learning datasets that can be used in CLArena.
Here are the base classes for continual learning datasets, which inherit from Lightning's `LightningDataModule`:
- `CLDataset`: The base class for all continual learning datasets.
    - `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
    - `CLSplitDataset`: The base class for split continual learning datasets. A child class of `CLDataset`.
    - `CLCombinedDataset`: The base class for combined continual learning datasets. A child class of `CLDataset`.
Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement continual learning datasets:

- [**Configure CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/CL-dataset)
- [**Implement Custom CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl_dataset)
- [**A Beginners' Guide to Continual Learning (CL Dataset)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset)
1r""" 2 3# Continual Learning Datasets 4 5This submodule provides the **continual learning datasets** that can be used in CLArena. 6 7Here are the base classes for continual learning datasets, which inherit from Lightning `LightningDataModule`: 8 9- `CLDataset`: The base class for all continual learning datasets. 10 - `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`. 11 - `CLSplitDataset`: The base class for split continual learning datasets. A child class of `CLDataset`. 12 - `CLCombinedDataset`: The base class for combined continual learning datasets. A child class of `CLDataset`. 13 14Please note that this is an API documantation. Please refer to the main documentation pages for more information about how to configure and implement continual learning datasets: 15 16- [**Configure CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/CL-dataset) 17- [**Implement Custom CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl_dataset) 18- [**A Beginners' Guide to Continual Learning (CL Dataset)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset) 19 20 21""" 22 23from .base import ( 24 CLDataset, 25 CLPermutedDataset, 26 CLSplitDataset, 27 CLCombinedDataset, 28) 29 30from .permuted_mnist import PermutedMNIST 31from .permuted_emnist import PermutedEMNIST 32from .permuted_fashionmnist import PermutedFashionMNIST 33from .permuted_kmnist import PermutedKMNIST 34from .permuted_notmnist import PermutedNotMNIST 35from .permuted_sign_language_mnist import PermutedSignLanguageMNIST 36from .permuted_ahdd import PermutedArabicHandwrittenDigits 37from .permuted_kannadamnist import PermutedKannadaMNIST 38from .permuted_svhn import PermutedSVHN 39from .permuted_country211 import PermutedCountry211 40from .permuted_imagenette import PermutedImagenette 41from .permuted_dtd import PermutedDTD 42from .permuted_cifar10 import PermutedCIFAR10 43from .permuted_cifar100 import PermutedCIFAR100 44from .permuted_caltech101 import PermutedCaltech101 45from .permuted_caltech256 import PermutedCaltech256 46from .permuted_eurosat import PermutedEuroSAT 47from .permuted_fgvc_aircraft import PermutedFGVCAircraft 48from .permuted_flowers102 import PermutedFlowers102 49from .permuted_food101 import PermutedFood101 50from .permuted_celeba import PermutedCelebA 51from .permuted_fer2013 import PermutedFER2013 52from .permuted_tinyimagenet import PermutedTinyImageNet 53from .permuted_oxford_iiit_pet import PermutedOxfordIIITPet 54from .permuted_pcam import PermutedPCAM 55from .permuted_renderedsst2 import PermutedRenderedSST2 56from .permuted_stanfordcars import PermutedStanfordCars 57from .permuted_sun397 import PermutedSUN397 58from .permuted_usps import PermutedUSPS 59from .permuted_SEMEION import PermutedSEMEION 60from .permuted_facescrub import PermutedFaceScrub 61from .permuted_cub2002011 import PermutedCUB2002011 62from .permuted_gtsrb import PermutedGTSRB 63from .permuted_linnaeus5 import PermutedLinnaeus5 64 65from .split_cifar10 import SplitCIFAR10 66from .split_mnist import SplitMNIST 67from .split_cifar100 import SplitCIFAR100 68from .split_tinyimagenet import SplitTinyImageNet 69from .split_cub2002011 import SplitCUB2002011 70 71from .combined import Combined 72 73 74__all__ = [ 75 "CLDataset", 76 "CLPermutedDataset", 77 "CLSplitDataset", 78 "CLCombinedDataset", 79 "combined", 80 "permuted_mnist", 81 "permuted_emnist", 82 
"permuted_fashionmnist", 83 "permuted_imagenette", 84 "permuted_sign_language_mnist", 85 "permuted_ahdd", 86 "permuted_kannadamnist", 87 "permuted_country211", 88 "permuted_dtd", 89 "permuted_fer2013", 90 "permuted_fgvc_aircraft", 91 "permuted_flowers102", 92 "permuted_food101", 93 "permuted_kmnist", 94 "permuted_notmnist", 95 "permuted_svhn", 96 "permuted_cifar10", 97 "permuted_cifar100", 98 "permuted_caltech101", 99 "permuted_caltech256", 100 "permuted_oxford_iiit_pet", 101 "permuted_celeba", 102 "permuted_eurosat", 103 "permuted_facescrub", 104 "permuted_pcam", 105 "permuted_renderedsst2", 106 "permuted_stanfordcars", 107 "permuted_sun397", 108 "permuted_usps", 109 "permuted_SEMEION", 110 "permuted_tinyimagenet", 111 "permuted_cub2002011", 112 "permuted_gtsrb", 113 "permuted_linnaeus5", 114 "split_mnist", 115 "split_cifar10", 116 "split_cifar100", 117 "split_tinyimagenet", 118 "split_cub2002011", 119]
Source of `CLDataset`:

```python
class CLDataset(LightningDataModule):
    r"""The base class of continual learning datasets."""

    def __init__(
        self,
        root: str | dict[int, str],
        num_tasks: int,
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
          If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
          If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
          If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
          If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
          If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
          If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
          If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        """
        super().__init__()

        self.root: dict[int, str] = (
            OmegaConf.to_container(root)
            if isinstance(root, DictConfig)
            else {t: root for t in range(1, num_tasks + 1)}
        )
        r"""The dict of root directories of the original data files for each task."""
        self.num_tasks: int = num_tasks
        r"""The maximum number of tasks supported by the dataset."""
        self.cl_paradigm: str
        r"""The continual learning paradigm."""
        self.batch_size: dict[int, int] = (
            OmegaConf.to_container(batch_size)
            if isinstance(batch_size, DictConfig)
            else {t: batch_size for t in range(1, num_tasks + 1)}
        )
        r"""The dict of batch sizes for each task."""
        self.num_workers: dict[int, int] = (
            OmegaConf.to_container(num_workers)
            if isinstance(num_workers, DictConfig)
            else {t: num_workers for t in range(1, num_tasks + 1)}
        )
        r"""The dict of numbers of workers for each task."""
        self.custom_transforms: dict[int, Callable | transforms.Compose | None] = (
            OmegaConf.to_container(custom_transforms)
            if isinstance(custom_transforms, DictConfig)
            else {t: custom_transforms for t in range(1, num_tasks + 1)}
        )
        r"""The dict of custom transforms for each task."""
        self.repeat_channels: dict[int, int | None] = (
            OmegaConf.to_container(repeat_channels)
            if isinstance(repeat_channels, DictConfig)
            else {t: repeat_channels for t in range(1, num_tasks + 1)}
        )
        r"""The dict of number of channels to repeat for each task."""
        self.to_tensor: dict[int, bool] = (
            OmegaConf.to_container(to_tensor)
            if isinstance(to_tensor, DictConfig)
            else {t: to_tensor for t in range(1, num_tasks + 1)}
        )
        r"""The dict of to_tensor flag for each task."""
        self.resize: dict[int, tuple[int, int] | None] = (
            {t: tuple(rs) if rs else None for t, rs in resize.items()}
            if isinstance(resize, DictConfig)
            else {
                t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1)
            }
        )
        r"""The dict of sizes to resize to for each task."""

        # task-specific attributes
        self.root_t: str
        r"""The root directory of the original data files for the current task `self.task_id`."""
        self.batch_size_t: int
        r"""The batch size for the current task `self.task_id`."""
        self.num_workers_t: int
        r"""The number of workers for the current task `self.task_id`."""
        self.custom_transforms_t: Callable | transforms.Compose | None
        r"""The custom transforms for the current task `self.task_id`."""
        self.repeat_channels_t: int | None
        r"""The number of channels to repeat for the current task `self.task_id`."""
        self.to_tensor_t: bool
        r"""The to_tensor flag for the current task `self.task_id`."""
        self.resize_t: tuple[int, int] | None
        r"""The size to resize for the current task `self.task_id`."""
        self.mean_t: float
        r"""The mean values for normalization for the current task `self.task_id`."""
        self.std_t: float
        r"""The standard deviation values for normalization for the current task `self.task_id`."""

        # dataset containers
        self.dataset_train_t: Any
        r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
        self.dataset_val_t: Any
        r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
        self.dataset_test: dict[int, Any] = {}
        r"""The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects.
        Can be PyTorch Dataset objects or any other dataset objects."""

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

        CLDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check if each task has been provided with necessary arguments
        for attr in [
            "root",
            "batch_size",
            "num_workers",
            "custom_transforms",
            "repeat_channels",
            "to_tensor",
            "resize",
        ]:
            value = getattr(self, attr)
            expected_keys = set(range(1, self.num_tasks + 1))
            if set(value.keys()) != expected_keys:
                raise ValueError(
                    f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
                )

    @abstractmethod
    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of each task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of each task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """

    @abstractmethod
    def prepare_data(self) -> None:
        r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`. This method is called at the beginning of each task."""

    def setup(self, stage: str) -> None:
        r"""Set up the dataset for different stages. This method is called at the beginning of each task.

        **Args:**
        - **stage** (`str`): the stage of the experiment; one of:
            - 'fit': training and validation datasets of the current task `self.task_id` are assigned to `self.dataset_train_t` and `self.dataset_val_t`.
            - 'test': a dict of test datasets of all seen tasks should be assigned to `self.dataset_test`.
        """
        if stage == "fit":
            # these two stages must be done together because a sanity check for validation is conducted before training
            pylogger.debug(
                "Construct train and validation dataset for task %d...", self.task_id
            )

            self.dataset_train_t, self.dataset_val_t = self.train_and_val_dataset()

            pylogger.info(
                "Train and validation dataset for task %d are ready.", self.task_id
            )
            pylogger.info(
                "Train dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_train_t),
            )
            pylogger.info(
                "Validation dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_val_t),
            )

        elif stage == "test":

            pylogger.debug("Construct test dataset for task %d...", self.task_id)

            self.dataset_test[self.task_id] = self.test_dataset()

            pylogger.info("Test dataset for task %d are ready.", self.task_id)
            pylogger.info(
                "Test dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_test[self.task_id]),
            )

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        self.task_id = task_id

        self.root_t = self.root[task_id]
        self.batch_size_t = self.batch_size[task_id]
        self.num_workers_t = self.num_workers[task_id]
        self.custom_transforms_t = self.custom_transforms[task_id]
        self.repeat_channels_t = self.repeat_channels[task_id]
        self.to_tensor_t = self.to_tensor[task_id]
        self.resize_t = self.resize[task_id]

        self.processed_task_ids.append(task_id)

    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
        r"""Set up tasks for continual learning main evaluation.

        **Args:**
        - **eval_tasks** (`list[int]`): the list of task IDs to evaluate.
        """
        for task_id in eval_tasks:
            self.setup_task_id(task_id=task_id)
            self.setup(stage="test")

    def set_cl_paradigm(self, cl_paradigm: str) -> None:
        r"""Set `cl_paradigm` to `self.cl_paradigm`. It is used to define the CL class map.

        **Args:**
        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
        """
        self.cl_paradigm = cl_paradigm

    def train_and_val_transforms(self) -> transforms.Compose:
        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.custom_transforms_t,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

    def test_transforms(self) -> transforms.Compose:
        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. It is used in subclasses when constructing the dataset.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters. No custom transforms for test

    def target_transform(self) -> ClassMapping:
        r"""Target transform to map the original class labels to CL class labels according to `self.cl_paradigm`. It can be used in subclasses when constructing the dataset.

        **Returns:**
        - **target_transform** (`Callable`): the target transform function.
        """

        cl_class_map = self.get_cl_class_map(task_id=self.task_id)

        target_transform = ClassMapping(class_map=cl_class_map)

        return target_transform

    @abstractmethod
    def train_and_val_dataset(self) -> tuple[Any, Any]:
        r"""Get the training and validation datasets of the current task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation datasets of the current task `self.task_id`.
        """

    @abstractmethod
    def test_dataset(self) -> Any:
        r"""Get the test dataset of the current task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - **test_dataset** (`Any`): the test dataset of the current task `self.task_id`.
        """

    def train_dataloader(self) -> DataLoader:
        r"""DataLoader generator for the train stage of the current task `self.task_id`. It is automatically called before training the task.

        **Returns:**
        - **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
        """

        pylogger.debug("Construct train dataloader for task %d...", self.task_id)

        return DataLoader(
            dataset=self.dataset_train_t,
            batch_size=self.batch_size_t,
            shuffle=True,  # shuffle train batch to prevent overfitting
            num_workers=self.num_workers_t,
            drop_last=True,  # to avoid batchnorm error (when batch_size is 1)
        )

    def val_dataloader(self) -> DataLoader:
        r"""DataLoader generator for the validation stage of the current task `self.task_id`. It is automatically called before the task's validation.

        **Returns:**
        - **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
        """

        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)

        return DataLoader(
            dataset=self.dataset_val_t,
            batch_size=self.batch_size_t,
            shuffle=False,  # don't have to shuffle val or test batch
            num_workers=self.num_workers_t,
        )

    def test_dataloader(self) -> dict[int, DataLoader]:
        r"""DataLoader generator for the test stage of the current task `self.task_id`. It is automatically called before testing the task.

        **Returns:**
        - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
        """

        pylogger.debug("Construct test dataloader for task %d...", self.task_id)

        return {
            task_id: DataLoader(
                dataset=dataset_test_t,
                batch_size=self.batch_size_t,
                shuffle=False,  # don't have to shuffle val or test batch
                num_workers=self.num_workers_t,
            )
            for task_id, dataset_test_t in self.dataset_test.items()
        }

    def __len__(self) -> int:
        r"""Get the number of tasks in the dataset.

        **Returns:**
        - **num_tasks** (`int`): the number of tasks in the dataset.
        """
        return self.num_tasks
```
The base class of continual learning datasets.
```python
def __init__(
    self,
    root: str | dict[int, str],
    num_tasks: int,
    batch_size: int | dict[int, int] = 1,
    num_workers: int | dict[int, int] = 0,
    custom_transforms: (
        Callable
        | transforms.Compose
        | None
        | dict[int, Callable | transforms.Compose | None]
    ) = None,
    repeat_channels: int | None | dict[int, int | None] = None,
    to_tensor: bool | dict[int, bool] = True,
    resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
) -> None:
```
**Args:**

- **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
- **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs, from 1 to `num_tasks`.
- **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
- **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the numbers of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
- **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
- **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
- **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
- **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize to for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
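To make the scalar and per-task forms concrete, here is a hedged sketch. It assumes the concrete subclass `PermutedMNIST` forwards these constructor arguments; note that the constructor broadcasts scalars itself, while per-task dicts are expected as Hydra/OmegaConf `DictConfig` objects (that is how they arrive from an experiment config):

```python
from omegaconf import OmegaConf

from clarena.cl_datasets import PermutedMNIST

# Scalar form: one value broadcast to every task.
dataset = PermutedMNIST(root="data/mnist", num_tasks=3, batch_size=64)

# Per-task form: passed as a DictConfig, as Hydra would provide it.
# Keys must be exactly the task IDs 1..num_tasks, or sanity_check() raises.
dataset = PermutedMNIST(
    root="data/mnist",
    num_tasks=3,
    batch_size=OmegaConf.create({1: 64, 2: 128, 3: 128}),
)
```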
Instance attributes (names recovered from the source above) include:

- `custom_transforms` (`dict[int, Callable | transforms.Compose | None]`): the dict of custom transforms for each task.
- `custom_transforms_t` (`Callable | transforms.Compose | None`): the custom transforms for the current task `self.task_id`.
- `dataset_train_t` (`Any`): the training dataset object. Can be a PyTorch Dataset object or any other dataset object.
- `dataset_val_t` (`Any`): the validation dataset object. Can be a PyTorch Dataset object or any other dataset object.
- `dataset_test` (`dict[int, Any]`): the dictionary storing the test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.
- `task_id` (`int`): task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to the number of tasks in the CL dataset.
```python
def sanity_check(self) -> None:
```
Sanity check: verifies that every per-task argument dict (`root`, `batch_size`, `num_workers`, `custom_transforms`, `repeat_channels`, `to_tensor`, `resize`) has keys that are exactly the task IDs 1 to `num_tasks`; raises `ValueError` otherwise.
```python
@abstractmethod
def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
```
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.
**Args:**

- **task_id** (`int`): the task ID to query the CL class map.

**Returns:**

- **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
    - If `self.cl_paradigm` is 'TIL', the mapped class labels of each task should be continuous integers from 0 to the number of classes.
    - If `self.cl_paradigm` is 'CIL', the mapped class labels of each task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
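For intuition, here is what this contract implies for a dataset with 10 classes per task (the `CLPermutedDataset` implementation later on this page follows exactly this pattern):

```python
# Expected CL class maps for task 2 of a dataset with 10 classes per task.
# TIL: every task reuses labels 0..9.
til_map_task_2 = {orig: orig for orig in range(10)}       # {0: 0, 1: 1, ..., 9: 9}
# CIL: task 2's labels are offset by the 10 classes of task 1.
cil_map_task_2 = {orig: orig + 10 for orig in range(10)}  # {0: 10, 1: 11, ..., 9: 19}
```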
```python
@abstractmethod
def prepare_data(self) -> None:
```
Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`. This method is called at the beginning of each task.
```python
def setup(self, stage: str) -> None:
```
Set up the dataset for different stages. This method is called at the beginning of each task.
**Args:**

- **stage** (`str`): the stage of the experiment; one of:
    - 'fit': training and validation datasets of the current task `self.task_id` are assigned to `self.dataset_train_t` and `self.dataset_val_t`.
    - 'test': a dict of test datasets of all seen tasks is assigned to `self.dataset_test`.
```python
def setup_task_id(self, task_id: int) -> None:
```
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.
**Args:**

- **task_id** (`int`): the target task ID.
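To make the call order concrete, here is a minimal sketch of the per-task protocol. In CLArena the trainer normally drives these calls; the sketch assumes the concrete subclass `PermutedMNIST` forwards the base constructor arguments:

```python
from clarena.cl_datasets import PermutedMNIST

dataset = PermutedMNIST(root="data/mnist", num_tasks=3, batch_size=64)
dataset.set_cl_paradigm("TIL")  # or "CIL"; needed before class maps are used

for t in range(1, dataset.num_tasks + 1):
    dataset.setup_task_id(t)    # select task t's root, batch size, transforms, ...
    dataset.prepare_data()      # download / prepare task t's data
    dataset.setup(stage="fit")  # build train/val datasets for task t
    train_loader = dataset.train_dataloader()
    # ... train on task t ...
    dataset.setup(stage="test")               # register task t's test dataset
    test_loaders = dataset.test_dataloader()  # dict over all seen tasks
```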
```python
def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
```
Set up tasks for continual learning main evaluation.
**Args:**

- **eval_tasks** (`list[int]`): the list of task IDs to evaluate.
```python
def set_cl_paradigm(self, cl_paradigm: str) -> None:
```
Set `cl_paradigm` to `self.cl_paradigm`. It is used to define the CL class map.
**Args:**

- **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
```python
def train_and_val_transforms(self) -> transforms.Compose:
```
Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.
**Returns:**

- **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
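As an illustration of the resulting order, a task configured with `repeat_channels=3`, `to_tensor=True`, `resize=(32, 32)`, and a random-flip custom transform would get a pipeline equivalent to the sketch below (the normalization statistics are assumed MNIST-style values, for illustration only):

```python
from torchvision import transforms

equivalent = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # repeat_channels
    transforms.ToTensor(),                        # to_tensor
    transforms.Resize((32, 32)),                  # resize
    transforms.RandomHorizontalFlip(),            # custom_transforms (train/val only)
    transforms.Normalize((0.1307,), (0.3081,)),   # e.g. MNIST mean/std
])
```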
```python
def test_transforms(self) -> transforms.Compose:
```
Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. It is used in subclasses when constructing the dataset.
**Returns:**

- **test_transforms** (`transforms.Compose`): the composed test transforms.
```python
def target_transform(self) -> ClassMapping:
```
Target transform to map the original class labels to CL class labels according to `self.cl_paradigm`. It can be used in subclasses when constructing the dataset.
**Returns:**

- **target_transform** (`Callable`): the target transform function.
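A hedged usage sketch, assuming `ClassMapping` behaves as a standard callable target transform (which is how it is passed to dataset constructors):

```python
# Sketch: mapping an original label to its CL label for the current task.
dataset.set_cl_paradigm("CIL")
dataset.setup_task_id(2)
to_cl_label = dataset.target_transform()
print(to_cl_label(7))  # e.g. 17 for a dataset with 10 classes per task under CIL
```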
```python
@abstractmethod
def train_and_val_dataset(self) -> tuple[Any, Any]:
```
Get the training and validation datasets of the current task `self.task_id`. It must be implemented by subclasses.
**Returns:**

- **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation datasets of the current task `self.task_id`.
```python
@abstractmethod
def test_dataset(self) -> Any:
```
Get the test dataset of the current task `self.task_id`. It must be implemented by subclasses.
**Returns:**

- **test_dataset** (`Any`): the test dataset of the current task `self.task_id`.
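Putting the abstract methods together, here is a hypothetical subclass sketch. The MNIST-based splits and statistics are illustrative assumptions only; see the "Implement Custom CL Dataset" page for the real recipe:

```python
from torch.utils.data import random_split
from torchvision.datasets import MNIST

from clarena.cl_datasets import CLDataset


class MyMNISTLikeCLDataset(CLDataset):
    def setup_task_id(self, task_id: int) -> None:
        super().setup_task_id(task_id)
        self.mean_t, self.std_t = (0.1307,), (0.3081,)  # normalization statistics

    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        if self.cl_paradigm == "TIL":
            return {i: i for i in range(10)}
        return {i: i + (task_id - 1) * 10 for i in range(10)}  # CIL offsets

    def prepare_data(self) -> None:
        MNIST(self.root_t, train=True, download=True)
        MNIST(self.root_t, train=False, download=True)

    def train_and_val_dataset(self):
        full = MNIST(
            self.root_t,
            train=True,
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
        )
        return random_split(full, [50_000, 10_000])

    def test_dataset(self):
        return MNIST(
            self.root_t,
            train=False,
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
        )
```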
```python
def train_dataloader(self) -> DataLoader:
```
DataLoader generator for the train stage of the current task `self.task_id`. It is automatically called before training the task.
**Returns:**

- **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
```python
def val_dataloader(self) -> DataLoader:
```
DataLoader generator for the validation stage of the current task `self.task_id`. It is automatically called before the task's validation.
**Returns:**

- **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
```python
def test_dataloader(self) -> dict[int, DataLoader]:
```
DataLoader generator for the test stage of the current task `self.task_id`. It is automatically called before testing the task.
**Returns:**

- **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
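A short sketch of consuming this dict when evaluating after learning a task:

```python
# Evaluate on every task seen so far; keys are task IDs.
test_loaders = dataset.test_dataloader()
for task_id, loader in test_loaders.items():
    for x, y in loader:
        ...  # compute metrics for this task
```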
Source of `CLPermutedDataset`:

```python
class CLPermutedDataset(CLDataset):
    r"""The base class of continual learning datasets constructed as permutations of an original dataset."""

    original_dataset_python_class: type[Dataset]
    r"""The original dataset class. **It must be provided in subclasses.**"""

    def __init__(
        self,
        root: str,
        num_tasks: int,
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
        permutation_mode: str = "first_channel_only",
        permutation_seeds: dict[int, int] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original dataset lives.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
          If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
          If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
          If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
          If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
          If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
          If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        - **permutation_mode** (`str`): the mode of permutation; one of:
            1. 'all': permute all pixels.
            2. 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
            3. 'first_channel_only': permute only the first channel.
        - **permutation_seeds** (`dict[int, int]` | `None`): the dict of seeds for permutation operations used to construct each task. Keys are task IDs and the values are permutation seeds for each task. Default is `None`, which creates a dict of seeds from 0 to `num_tasks`-1.
        """
        super().__init__(
            root=root,
            num_tasks=num_tasks,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.original_dataset_constants: type[DatasetConstants] = (
            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
        )
        r"""The original dataset constants class."""

        self.permutation_mode: str = permutation_mode
        r"""The mode of permutation."""
        self.permutation_seeds: dict[int, int] = (
            permutation_seeds
            if permutation_seeds
            else {t: t - 1 for t in range(1, num_tasks + 1)}
        )
        r"""The dict of permutation seeds for each task."""

        self.permutation_seed_t: int
        r"""The permutation seed for the current task `self.task_id`."""
        self.permute_transform_t: Permute
        r"""The permutation transform for the current task `self.task_id`."""

        CLPermutedDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check the permutation mode
        if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]:
            raise ValueError(
                "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'."
            )

        # check the permutation seeds
        expected_keys = set(range(1, self.num_tasks + 1))
        if set(self.permutation_seeds.keys()) != expected_keys:
            raise ValueError(
                "permutation_seeds dict keys must be consecutive integers from 1 to num_tasks."
            )

    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """

        num_classes_t = (
            self.original_dataset_constants.NUM_CLASSES
        )  # the same as the original dataset
        class_map_t = (
            self.original_dataset_constants.CLASS_MAP
        )  # the same as the original dataset

        if self.cl_paradigm == "TIL":
            return {class_map_t[i]: i for i in range(num_classes_t)}
        if self.cl_paradigm == "CIL":
            return {
                class_map_t[i]: i + (task_id - 1) * num_classes_t
                for i in range(num_classes_t)
            }

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        CLDataset.setup_task_id(self, task_id)

        self.mean_t = (
            self.original_dataset_constants.MEAN
        )  # the same as the original dataset
        self.std_t = (
            self.original_dataset_constants.STD
        )  # the same as the original dataset

        num_channels = (
            self.original_dataset_constants.NUM_CHANNELS
            if self.repeat_channels_t is None
            else self.repeat_channels_t
        )

        if (
            hasattr(self.original_dataset_constants, "IMG_SIZE")
            or self.resize_t is not None
        ):
            img_size = (
                self.original_dataset_constants.IMG_SIZE
                if self.resize_t is None
                else torch.Size(self.resize_t)
            )
        else:
            raise AttributeError(
                "The original dataset has different image sizes. Please resize the images to a fixed size by specifying hyperparameter: resize."
            )

        # set up the permutation transform
        self.permutation_seed_t = self.permutation_seeds[task_id]
        self.permute_transform_t = Permute(
            num_channels=num_channels,
            img_size=img_size,
            mode=self.permutation_mode,
            seed=self.permutation_seed_t,
        )

    def train_and_val_transforms(self) -> transforms.Compose:
        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform also applies.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """

        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.permute_transform_t,  # permutation is included here
                        self.custom_transforms_t,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

    def test_transforms(self) -> transforms.Compose:
        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform also applies.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """

        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.permute_transform_t,  # permutation is included here
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters. No custom transforms for test
```
The base class of continual learning datasets constructed as permutations of an original dataset.
```python
def __init__(
    self,
    root: str,
    num_tasks: int,
    batch_size: int | dict[int, int] = 1,
    num_workers: int | dict[int, int] = 0,
    custom_transforms: (
        Callable
        | transforms.Compose
        | None
        | dict[int, Callable | transforms.Compose | None]
    ) = None,
    repeat_channels: int | None | dict[int, int | None] = None,
    to_tensor: bool | dict[int, bool] = True,
    resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    permutation_mode: str = "first_channel_only",
    permutation_seeds: dict[int, int] | None = None,
) -> None:
```
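For orientation, a minimal construction sketch (not part of the library source): it builds `PermutedMNIST`, one of the concrete subclasses exported by `clarena.cl_datasets`, assuming the subclass forwards the base-class arguments documented above. Paths and seed values are placeholders.

```python
from clarena.cl_datasets import PermutedMNIST

cl_dataset = PermutedMNIST(
    root="data/mnist",  # placeholder: where the original MNIST files live
    num_tasks=10,
    batch_size=64,  # an int applies to every task...
    num_workers={t: 4 for t in range(1, 11)},  # ...a dict sets the value per task
    permutation_mode="first_channel_only",
    permutation_seeds={t: 100 + t for t in range(1, 11)},  # keys must be exactly 1..num_tasks
)

cl_dataset.setup_task_id(1)  # select task 1; must precede setup()
```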
506 def sanity_check(self) -> None: 507 r"""Sanity check.""" 508 509 # check the permutation mode 510 if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]: 511 raise ValueError( 512 "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'." 513 ) 514 515 # check the permutation seeds 516 expected_keys = set(range(1, self.num_tasks + 1)) 517 if set(self.permutation_seeds.keys()) != expected_keys: 518 raise ValueError( 519 f"{self.permutation_seeds} dict keys must be consecutive integers from 1 to num_tasks." 520 )
Sanity check.
522 def get_cl_class_map(self, task_id: int) -> dict[str | int, int]: 523 r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. 524 525 **Args:** 526 - **task_id** (`int`): the task ID to query the CL class map. 527 528 **Returns:** 529 - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning. 530 - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes. 531 - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task. 532 """ 533 534 num_classes_t = ( 535 self.original_dataset_constants.NUM_CLASSES 536 ) # the same with the original dataset 537 class_map_t = ( 538 self.original_dataset_constants.CLASS_MAP 539 ) # the same with the original dataset 540 541 if self.cl_paradigm == "TIL": 542 return {class_map_t[i]: i for i in range(num_classes_t)} 543 if self.cl_paradigm == "CIL": 544 return { 545 class_map_t[i]: i + (task_id - 1) * num_classes_t 546 for i in range(num_classes_t) 547 }
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
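As a worked sketch of this mapping, assume a 10-class original dataset whose `CLASS_MAP` is the identity: the TIL map is identical for every task, while the CIL map offsets labels by the classes accumulated in earlier tasks.

```python
# Reproduces the mapping logic above for a hypothetical 10-class dataset.
num_classes = 10
class_map = {i: i for i in range(num_classes)}  # identity CLASS_MAP (an assumption)

til_map = {class_map[i]: i for i in range(num_classes)}  # the same for every task

task_id = 3
cil_map = {class_map[i]: i + (task_id - 1) * num_classes for i in range(num_classes)}
assert cil_map[0] == 20 and cil_map[9] == 29  # task 3 occupies labels 20..29
```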
549 def setup_task_id(self, task_id: int) -> None: 550 r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called. 551 552 **Args:** 553 - **task_id** (`int`): the target task ID. 554 """ 555 556 CLDataset.setup_task_id(self, task_id) 557 558 self.mean_t = ( 559 self.original_dataset_constants.MEAN 560 ) # the same with the original dataset 561 self.std_t = ( 562 self.original_dataset_constants.STD 563 ) # the same with the original dataset 564 565 num_channels = ( 566 self.original_dataset_constants.NUM_CHANNELS 567 if self.repeat_channels_t is None 568 else self.repeat_channels_t 569 ) 570 571 if ( 572 hasattr(self.original_dataset_constants, "IMG_SIZE") 573 or self.resize_t is not None 574 ): 575 img_size = ( 576 self.original_dataset_constants.IMG_SIZE 577 if self.resize_t is None 578 else torch.Size(self.resize_t) 579 ) 580 else: 581 raise AttributeError( 582 "The original dataset has different image sizes. Please resize the images to a fixed size by specifying hyperparameter: resize." 583 ) 584 585 # set up the permutation transform 586 self.permutation_seed_t = self.permutation_seeds[task_id] 587 self.permute_transform_t = Permute( 588 num_channels=num_channels, 589 img_size=img_size, 590 mode=self.permutation_mode, 591 seed=self.permutation_seed_t, 592 )
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.
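The `Permute` transform is internal to CLArena, so as a rough sketch only: a fixed-seed pixel permutation in the spirit of the 'all' mode might look like the following on a `(C, H, W)` tensor (the real class also implements the 'by_channel' and 'first_channel_only' modes).

```python
import torch

def make_permute_all(num_channels: int, img_size: torch.Size, seed: int):
    """Sketch of an 'all'-mode permutation: one fixed shuffle of every pixel."""
    n = num_channels * img_size[0] * img_size[1]
    generator = torch.Generator().manual_seed(seed)  # fixed seed => one permutation per task
    perm = torch.randperm(n, generator=generator)

    def permute(img: torch.Tensor) -> torch.Tensor:
        # Flatten, apply the stored permutation, restore the original shape.
        return img.reshape(-1)[perm].reshape(num_channels, *img_size)

    return permute

permute = make_permute_all(num_channels=1, img_size=torch.Size((28, 28)), seed=0)
x = torch.rand(1, 28, 28)
assert permute(x).shape == x.shape
```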
594 def train_and_val_transforms(self) -> transforms.Compose: 595 r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform also applies. 596 597 **Returns:** 598 - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms. 599 """ 600 601 repeat_channels_transform = ( 602 transforms.Grayscale(num_output_channels=self.repeat_channels_t) 603 if self.repeat_channels_t is not None 604 else None 605 ) 606 to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None 607 resize_transform = ( 608 transforms.Resize(self.resize_t) if self.resize_t is not None else None 609 ) 610 normalization_transform = transforms.Normalize(self.mean_t, self.std_t) 611 612 return transforms.Compose( 613 list( 614 filter( 615 None, 616 [ 617 repeat_channels_transform, 618 to_tensor_transform, 619 resize_transform, 620 self.permute_transform_t, # permutation is included here 621 self.custom_transforms_t, 622 normalization_transform, 623 ], 624 ) 625 ) 626 ) # the order of transforms matters
Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform also applies.
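One practical consequence of this ordering: a `custom_transforms` entry runs after the permute transform, so it receives already-permuted tensors, and spatially structured augmentations (crops, flips) may not behave as intended on permuted images. A hedged example of a tensor-level augmentation that slots into this position (the transform choice is illustrative, not a recommendation):

```python
from torchvision import transforms

# Applied between the permute transform and normalization, per the composition above.
custom = transforms.RandomErasing(p=0.5)
```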
628 def test_transforms(self) -> transforms.Compose: 629 r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform also applies. 630 631 **Returns:** 632 - **test_transforms** (`transforms.Compose`): the composed test transforms. 633 """ 634 635 repeat_channels_transform = ( 636 transforms.Grayscale(num_output_channels=self.repeat_channels_t) 637 if self.repeat_channels_t is not None 638 else None 639 ) 640 to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None 641 resize_transform = ( 642 transforms.Resize(self.resize_t) if self.resize_t is not None else None 643 ) 644 normalization_transform = transforms.Normalize(self.mean_t, self.std_t) 645 646 return transforms.Compose( 647 list( 648 filter( 649 None, 650 [ 651 repeat_channels_transform, 652 to_tensor_transform, 653 resize_transform, 654 self.permute_transform_t, # permutation is included here 655 normalization_transform, 656 ], 657 ) 658 ) 659 ) # the order of transforms matters. No custom transforms for test
Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform also applies.
662class CLSplitDataset(CLDataset): 663 r"""The base class of continual learning datasets constructed as splits of an original dataset.""" 664 665 original_dataset_python_class: type[Dataset] 666 r"""The original dataset class. **It must be provided in subclasses.** """
The base class of continual learning datasets constructed as splits of an original dataset.
668 def __init__( 669 self, 670 root: str, 671 class_split: dict[int, list[int]], 672 batch_size: int | dict[int, int] = 1, 673 num_workers: int | dict[int, int] = 0, 674 custom_transforms: ( 675 Callable 676 | transforms.Compose 677 | None 678 | dict[int, Callable | transforms.Compose | None] 679 ) = None, 680 repeat_channels: int | None | dict[int, int | None] = None, 681 to_tensor: bool | dict[int, bool] = True, 682 resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, 683 ) -> None: 684 r""" 685 **Args:** 686 - **root** (`str`): the root directory where the original dataset lives. 687 - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs and the values are lists of class labels (integers starting from 0) to split for each task. 688 - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. 689 If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks. 690 - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. 691 If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks. 692 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included. 693 If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied. 694 - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. 695 If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied. 696 - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. 697 If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks. 698 - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. 699 If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied. 700 """ 701 super().__init__( 702 root=root, 703 num_tasks=len( 704 class_split 705 ), # num_tasks is not explicitly provided, but derived from the class_split length 706 batch_size=batch_size, 707 num_workers=num_workers, 708 custom_transforms=custom_transforms, 709 repeat_channels=repeat_channels, 710 to_tensor=to_tensor, 711 resize=resize, 712 ) 713 714 self.original_dataset_constants: type[DatasetConstants] = ( 715 DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class] 716 ) 717 r"""The original dataset constants class. 
""" 718 719 self.class_split: dict[int, list[int]] = OmegaConf.to_container(class_split) 720 r"""The dict of class splits for each task.""" 721 722 CLSplitDataset.sanity_check(self)
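A minimal construction sketch using `SplitCIFAR10`, one of the subclasses exported by `clarena.cl_datasets` (paths and the split choice are placeholders; `num_tasks` is derived from `class_split`). Note that `__init__` passes `class_split` through `OmegaConf.to_container`, so in a Hydra setup it normally arrives as a `DictConfig` rather than the plain dict shown here.

```python
from clarena.cl_datasets import SplitCIFAR10

cl_dataset = SplitCIFAR10(
    root="data/cifar10",  # placeholder
    class_split={1: [0, 1], 2: [2, 3], 3: [4, 5], 4: [6, 7], 5: [8, 9]},
    batch_size=64,
    num_workers=4,
)  # five 2-class tasks; each split must contain at least 2 classes
```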
724 def sanity_check(self) -> None: 725 r"""Sanity check.""" 726 727 # check the class split 728 expected_keys = set(range(1, self.num_tasks + 1)) 729 if set(self.class_split.keys()) != expected_keys: 730 raise ValueError( 731 f"{self.class_split} dict keys must be consecutive integers from 1 to num_tasks." 732 ) 733 if any(len(split) < 2 for split in self.class_split.values()): 734 raise ValueError("Each class split must contain at least 2 elements!")
Sanity check.
736 def get_cl_class_map(self, task_id: int) -> dict[str | int, int]: 737 r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. 738 739 **Args:** 740 - **task_id** (`int`): the task ID to query the CL class map. 741 742 **Returns:** 743 - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning. 744 - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes. 745 - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task. 746 """ 747 num_classes_t = len( 748 self.class_split[task_id] 749 ) # the number of classes in the current task, i.e. the length of the class split 750 class_map_t = ( 751 self.original_dataset_constants.CLASS_MAP 752 ) # the same with the original dataset 753 754 if self.cl_paradigm == "TIL": 755 return { 756 class_map_t[self.class_split[task_id][i]]: i 757 for i in range(num_classes_t) 758 } 759 if self.cl_paradigm == "CIL": 760 num_classes_previous = sum( 761 len(self.class_split[i]) for i in range(1, task_id) 762 ) 763 return { 764 class_map_t[self.class_split[task_id][i]]: num_classes_previous + i 765 for i in range(num_classes_t) 766 }
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
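A worked sketch of this mapping for a two-task split, again assuming an identity `CLASS_MAP`:

```python
class_split = {1: [0, 1], 2: [2, 3]}
class_map = {i: i for i in range(10)}  # identity CLASS_MAP (an assumption)

task_id = 2
num_classes_previous = sum(len(class_split[i]) for i in range(1, task_id))  # = 2

til_map = {class_map[c]: i for i, c in enumerate(class_split[task_id])}
cil_map = {class_map[c]: num_classes_previous + i for i, c in enumerate(class_split[task_id])}

assert til_map == {2: 0, 3: 1}  # task-local labels start from 0
assert cil_map == {2: 2, 3: 3}  # labels continue after task 1's two classes
```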
768 def setup_task_id(self, task_id: int) -> None: 769 r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called. 770 771 **Args:** 772 - **task_id** (`int`): the target task ID. 773 """ 774 super().setup_task_id(task_id) 775 776 self.mean_t = ( 777 self.original_dataset_constants.MEAN 778 ) # the same with the original dataset 779 self.std_t = ( 780 self.original_dataset_constants.STD 781 ) # the same with the original dataset
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.
783 @abstractmethod 784 def get_subset_of_classes(self, dataset: Dataset) -> Dataset: 785 r"""Get a subset of classes from the dataset for the current task `self.task_id`. It is used when constructing the split. **It must be implemented by subclasses.** 786 787 **Args:** 788 - **dataset** (`Dataset`): the dataset to retrieve the subset from. 789 790 **Returns:** 791 - **subset** (`Dataset`): the subset of classes from the dataset. 792 """
Get a subset of classes from the dataset for the current task `self.task_id`. It is used when constructing the split. It must be implemented by subclasses.
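Since `get_subset_of_classes` is abstract, each subclass decides how to filter its original dataset. A hedged sketch of one common implementation, assuming the wrapped torchvision-style dataset exposes a `targets` sequence (true for CIFAR-style datasets; others store labels differently):

```python
from torch.utils.data import Dataset, Subset

def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
    """Keep only the samples whose label belongs to the current task's split."""
    wanted = set(self.class_split[self.task_id])
    indices = [i for i, label in enumerate(dataset.targets) if int(label) in wanted]
    return Subset(dataset, indices)
```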
795class CLCombinedDataset(CLDataset): 796 r"""The base class of continual learning datasets constructed as combinations of several single-task datasets (one dataset per task)."""
The base class of continual learning datasets constructed as combinations of several single-task datasets (one dataset per task).
798 def __init__( 799 self, 800 datasets: dict[int, str], 801 root: str | dict[int, str], 802 batch_size: int | dict[int, int] = 1, 803 num_workers: int | dict[int, int] = 0, 804 custom_transforms: ( 805 Callable 806 | transforms.Compose 807 | None 808 | dict[int, Callable | transforms.Compose | None] 809 ) = None, 810 repeat_channels: int | None | dict[int, int | None] = None, 811 to_tensor: bool | dict[int, bool] = True, 812 resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, 813 ) -> None: 814 r""" 815 **Args:** 816 - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task. 817 - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live. 818 If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks. 819 Note that there is no `num_tasks` argument: the number of tasks is derived from the length of `datasets`. 820 - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders. 821 If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks. 822 - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders. 823 If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks. 824 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included. 825 If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied. 826 - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. 827 If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied. 828 - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. 829 If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks. 830 - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. 831 If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied. 
832 """ 833 super().__init__( 834 root=root, 835 num_tasks=len( 836 datasets 837 ), # num_tasks is not explicitly provided, but derived from the datasets length 838 batch_size=batch_size, 839 num_workers=num_workers, 840 custom_transforms=custom_transforms, 841 repeat_channels=repeat_channels, 842 to_tensor=to_tensor, 843 resize=resize, 844 ) 845 846 self.original_dataset_python_classes: dict[int, Dataset] = { 847 t: str_to_class(dataset_class_path) 848 for t, dataset_class_path in datasets.items() 849 } 850 r"""The dict of dataset classes for each task.""" 851 self.original_dataset_python_class_t: Dataset 852 r"""The dataset class for the current task `self.task_id`.""" 853 self.original_dataset_constants_t: type[DatasetConstants] 854 r"""The original dataset constants class for the current task `self.task_id`."""
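A minimal construction sketch for the `Combined` subclass exported by `clarena.cl_datasets` (class paths and roots are illustrative; each path must resolve, via `str_to_class`, to a dataset class registered in `DATASET_CONSTANTS_MAPPING`):

```python
from clarena.cl_datasets import Combined

cl_dataset = Combined(
    datasets={
        1: "torchvision.datasets.MNIST",
        2: "torchvision.datasets.CIFAR10",
    },
    root={1: "data/mnist", 2: "data/cifar10"},  # per-task root directories
    batch_size=64,
    repeat_channels={1: 3, 2: None},  # lift grayscale MNIST to 3 channels
    resize=(32, 32),  # give both tasks a common image size
)  # num_tasks is derived from len(datasets)
```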
856 def get_cl_class_map(self, task_id: int) -> dict[str | int, int]: 857 r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. 858 859 **Args:** 860 - **task_id** (`int`): the task ID to query the CL class map. 861 862 **Returns:** 863 - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning. 864 - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes. 865 - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task. 866 """ 867 original_dataset_python_class_t = self.original_dataset_python_classes[task_id] 868 original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[ 869 original_dataset_python_class_t 870 ] 871 num_classes_t = original_dataset_constants_t.NUM_CLASSES 872 class_map_t = original_dataset_constants_t.CLASS_MAP 873 874 if self.cl_paradigm == "TIL": 875 return {class_map_t[i]: i for i in range(num_classes_t)} 876 if self.cl_paradigm == "CIL": 877 num_classes_previous = sum( 878 [ 879 DATASET_CONSTANTS_MAPPING[ 880 self.original_dataset_python_classes[i] 881 ].NUM_CLASSES 882 for i in range(1, task_id) 883 ] 884 ) 885 return { 886 class_map_t[i]: num_classes_previous + i for i in range(num_classes_t) 887 }
Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
889 def setup_task_id(self, task_id: int) -> None: 890 r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called. 891 892 **Args:** 893 - **task_id** (`int`): the target task ID. 894 """ 895 896 self.original_dataset_python_class_t = self.original_dataset_python_classes[ 897 task_id 898 ] 899 900 self.original_dataset_constants_t: type[DatasetConstants] = ( 901 DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class_t] 902 ) 903 904 super().setup_task_id(task_id) 905 906 self.mean_t = self.original_dataset_constants_t.MEAN 907 self.std_t = self.original_dataset_constants_t.STD
Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.