clarena.cl_datasets

Continual Learning Datasets

This submodule provides the continual learning datasets that can be used in CLArena.

Here are the base classes for continual learning datasets, which inherit from Lightning LightningDataModule:

Please note that this is an API documentation. Please refer to the main documentation pages for more information about how to configure and implement continual learning datasets:

  1r"""
  2
  3# Continual Learning Datasets
  4
  5This submodule provides the **continual learning datasets** that can be used in CLArena.
  6
  7Here are the base classes for continual learning datasets, which inherit from Lightning `LightningDataModule`:
  8
  9- `CLDataset`: The base class for all continual learning datasets.
 10    - `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
 11    - `CLRotatedDataset`: The base class for rotated continual learning datasets. A child class of `CLDataset`.
 12    - `CLSplitDataset`: The base class for split continual learning datasets. A child class of `CLDataset`.
 13    - `CLCombinedDataset`: The base class for combined continual learning datasets. A child class of `CLDataset`.
 14
 15Please note that this is an API documantation. Please refer to the main documentation pages for more information about how to configure and implement continual learning datasets:
 16
 17- [**Configure CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/CL-dataset)
 18- [**Implement Custom CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl_dataset)
 19- [**A Beginners' Guide to Continual Learning (CL Dataset)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset)
 20
 21
 22"""
 23
 24from .base import (
 25    CLDataset,
 26    CLPermutedDataset,
 27    CLRotatedDataset,
 28    CLSplitDataset,
 29    CLCombinedDataset,
 30)
 31
 32from .permuted_mnist import PermutedMNIST
 33from .rotated_mnist import RotatedMNIST
 34from .permuted_emnist import PermutedEMNIST
 35from .permuted_fashionmnist import PermutedFashionMNIST
 36from .permuted_kmnist import PermutedKMNIST
 37from .permuted_notmnist import PermutedNotMNIST
 38from .permuted_sign_language_mnist import PermutedSignLanguageMNIST
 39from .permuted_ahdd import PermutedArabicHandwrittenDigits
 40from .permuted_kannadamnist import PermutedKannadaMNIST
 41from .permuted_svhn import PermutedSVHN
 42from .permuted_country211 import PermutedCountry211
 43from .permuted_imagenette import PermutedImagenette
 44from .permuted_dtd import PermutedDTD
 45from .permuted_cifar10 import PermutedCIFAR10
 46from .permuted_cifar100 import PermutedCIFAR100
 47from .permuted_caltech101 import PermutedCaltech101
 48from .permuted_caltech256 import PermutedCaltech256
 49from .permuted_eurosat import PermutedEuroSAT
 50from .permuted_fgvc_aircraft import PermutedFGVCAircraft
 51from .permuted_flowers102 import PermutedFlowers102
 52from .permuted_food101 import PermutedFood101
 53from .permuted_celeba import PermutedCelebA
 54from .permuted_fer2013 import PermutedFER2013
 55from .permuted_tinyimagenet import PermutedTinyImageNet
 56from .permuted_oxford_iiit_pet import PermutedOxfordIIITPet
 57from .permuted_pcam import PermutedPCAM
 58from .permuted_renderedsst2 import PermutedRenderedSST2
 59from .permuted_stanfordcars import PermutedStanfordCars
 60from .permuted_sun397 import PermutedSUN397
 61from .permuted_usps import PermutedUSPS
 62from .permuted_SEMEION import PermutedSEMEION
 63from .permuted_facescrub import PermutedFaceScrub
 64from .permuted_cub2002011 import PermutedCUB2002011
 65from .permuted_gtsrb import PermutedGTSRB
 66from .permuted_linnaeus5 import PermutedLinnaeus5
 67
 68from .split_cifar10 import SplitCIFAR10
 69from .split_mnist import SplitMNIST
 70from .split_cifar100 import SplitCIFAR100
 71from .split_tinyimagenet import SplitTinyImageNet
 72from .split_cub2002011 import SplitCUB2002011
 73
 74from .combined import Combined
 75
 76
 77__all__ = [
 78    "CLDataset",
 79    "CLPermutedDataset",
 80    "CLRotatedDataset",
 81    "CLSplitDataset",
 82    "CLCombinedDataset",
 83    "combined",
 84    "permuted_mnist",
 85    "rotated_mnist",
 86    "permuted_emnist",
 87    "permuted_fashionmnist",
 88    "permuted_imagenette",
 89    "permuted_sign_language_mnist",
 90    "permuted_ahdd",
 91    "permuted_kannadamnist",
 92    "permuted_country211",
 93    "permuted_dtd",
 94    "permuted_fer2013",
 95    "permuted_fgvc_aircraft",
 96    "permuted_flowers102",
 97    "permuted_food101",
 98    "permuted_kmnist",
 99    "permuted_notmnist",
100    "permuted_svhn",
101    "permuted_cifar10",
102    "permuted_cifar100",
103    "permuted_caltech101",
104    "permuted_caltech256",
105    "permuted_oxford_iiit_pet",
106    "permuted_celeba",
107    "permuted_eurosat",
108    "permuted_facescrub",
109    "permuted_pcam",
110    "permuted_renderedsst2",
111    "permuted_stanfordcars",
112    "permuted_sun397",
113    "permuted_usps",
114    "permuted_SEMEION",
115    "permuted_tinyimagenet",
116    "permuted_cub2002011",
117    "permuted_gtsrb",
118    "permuted_linnaeus5",
119    "split_mnist",
120    "split_cifar10",
121    "split_cifar100",
122    "split_tinyimagenet",
123    "split_cub2002011",
124]
class CLDataset(lightning.pytorch.core.datamodule.LightningDataModule):
 35class CLDataset(LightningDataModule):
 36    r"""The base class of continual learning datasets."""
 37
 38    def __init__(
 39        self,
 40        root: str | dict[int, str],
 41        num_tasks: int,
 42        batch_size: int | dict[int, int] = 1,
 43        num_workers: int | dict[int, int] = 0,
 44        custom_transforms: (
 45            Callable
 46            | transforms.Compose
 47            | None
 48            | dict[int, Callable | transforms.Compose | None]
 49        ) = None,
 50        repeat_channels: int | None | dict[int, int | None] = None,
 51        to_tensor: bool | dict[int, bool] = True,
 52        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
 53    ) -> None:
 54        r"""
 55        **Args:**
 56        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
 57        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
 58        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
 59        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
 60        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
 61        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
 62        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
 63        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
 64        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
 65        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
 66        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
 67        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
 68        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
 69        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
 70        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
 71        """
 72        super().__init__()
 73
 74        self.root: dict[int, str] = (
 75            OmegaConf.to_container(root)
 76            if isinstance(root, DictConfig)
 77            else {t: root for t in range(1, num_tasks + 1)}
 78        )
 79        r"""The dict of root directories of the original data files for each task."""
 80        self.num_tasks: int = num_tasks
 81        r"""The maximum number of tasks supported by the dataset."""
 82        self.cl_paradigm: str
 83        r"""The continual learning paradigm."""
 84        self.batch_size: dict[int, int] = (
 85            OmegaConf.to_container(batch_size)
 86            if isinstance(batch_size, DictConfig)
 87            else {t: batch_size for t in range(1, num_tasks + 1)}
 88        )
 89        r"""The dict of batch sizes for each task."""
 90        self.num_workers: dict[int, int] = (
 91            OmegaConf.to_container(num_workers)
 92            if isinstance(num_workers, DictConfig)
 93            else {t: num_workers for t in range(1, num_tasks + 1)}
 94        )
 95        r"""The dict of numbers of workers for each task."""
 96        self.custom_transforms: dict[int, Callable | transforms.Compose | None] = (
 97            OmegaConf.to_container(custom_transforms)
 98            if isinstance(custom_transforms, DictConfig)
 99            else {t: custom_transforms for t in range(1, num_tasks + 1)}
100        )
101        r"""The dict of custom transforms for each task."""
102        self.repeat_channels: dict[int, int | None] = (
103            OmegaConf.to_container(repeat_channels)
104            if isinstance(repeat_channels, DictConfig)
105            else {t: repeat_channels for t in range(1, num_tasks + 1)}
106        )
107        r"""The dict of number of channels to repeat for each task."""
108        self.to_tensor: dict[int, bool] = (
109            OmegaConf.to_container(to_tensor)
110            if isinstance(to_tensor, DictConfig)
111            else {t: to_tensor for t in range(1, num_tasks + 1)}
112        )
113        r"""The dict of to_tensor flag for each task. """
114        self.resize: dict[int, tuple[int, int] | None] = (
115            {t: tuple(rs) if rs else None for t, rs in resize.items()}
116            if isinstance(resize, DictConfig)
117            else {
118                t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1)
119            }
120        )
121        r"""The dict of sizes to resize to for each task."""
122
123        # task-specific attributes
124        self.root_t: str
125        r"""The root directory of the original data files for the current task `self.task_id`."""
126        self.batch_size_t: int
127        r"""The batch size for the current task `self.task_id`."""
128        self.num_workers_t: int
129        r"""The number of workers for the current task `self.task_id`."""
130        self.custom_transforms_t: Callable | transforms.Compose | None
131        r"""The custom transforms for the current task `self.task_id`."""
132        self.repeat_channels_t: int | None
133        r"""The number of channels to repeat for the current task `self.task_id`."""
134        self.to_tensor_t: bool
135        r"""The to_tensor flag for the current task `self.task_id`."""
136        self.resize_t: tuple[int, int] | None
137        r"""The size to resize for the current task `self.task_id`."""
138        self.mean_t: float
139        r"""The mean values for normalization for the current task `self.task_id`."""
140        self.std_t: float
141        r"""The standard deviation values for normalization for the current task `self.task_id`."""
142
143        # dataset containers
144        self.dataset_train_t: Any
145        r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
146        self.dataset_val_t: Any
147        r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
148        self.dataset_test: dict[int, Any] = {}
149        r"""The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects."""
150
151        # task ID control
152        self.task_id: int
153        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
154        self.processed_task_ids: list[int] = []
155        r"""Task IDs that have been processed."""
156
157        CLDataset.sanity_check(self)
158
159    def sanity_check(self) -> None:
160        r"""Sanity check."""
161
162        # check if each task has been provided with necessary arguments
163        for attr in [
164            "root",
165            "batch_size",
166            "num_workers",
167            "custom_transforms",
168            "repeat_channels",
169            "to_tensor",
170            "resize",
171        ]:
172            value = getattr(self, attr)
173            expected_keys = set(range(1, self.num_tasks + 1))
174            if set(value.keys()) != expected_keys:
175                raise ValueError(
176                    f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
177                )
178
179    @abstractmethod
180    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
181        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.
182
183        **Args:**
184        - **task_id** (`int`): the task ID to query the CL class map.
185
186        **Returns:**
187        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
188            - If `self.cl_paradigm` is 'TIL', the mapped class labels of each task should be continuous integers from 0 to the number of classes.
189            - If `self.cl_paradigm` is 'CIL', the mapped class labels of each task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
190        """
191
192    @abstractmethod
193    def prepare_data(self) -> None:
194        r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`. This method is called at the beginning of each task."""
195
196    def setup(self, stage: str) -> None:
197        r"""Set up the dataset for different stages. This method is called at the beginning of each task.
198
199        **Args:**
200        - **stage** (`str`): the stage of the experiment; one of:
201            - 'fit': training and validation datasets of the current task `self.task_id` are assigned to `self.dataset_train_t` and `self.dataset_val_t`.
202            - 'test': a dict of test datasets of all seen tasks should be assigned to `self.dataset_test`.
203        """
204        if stage == "fit":
205            # these two stages must be done together because a sanity check for validation is conducted before training
206            pylogger.debug(
207                "Construct train and validation dataset for task %d...", self.task_id
208            )
209
210            self.dataset_train_t, self.dataset_val_t = self.train_and_val_dataset()
211
212            pylogger.info(
213                "Train and validation dataset for task %d are ready.", self.task_id
214            )
215            pylogger.info(
216                "Train dataset for task %d size: %d",
217                self.task_id,
218                len(self.dataset_train_t),
219            )
220            pylogger.info(
221                "Validation dataset for task %d size: %d",
222                self.task_id,
223                len(self.dataset_val_t),
224            )
225
226        elif stage == "test":
227            pylogger.debug("Construct test dataset for task %d...", self.task_id)
228
229            self.dataset_test[self.task_id] = self.test_dataset()
230
231            pylogger.info("Test dataset for task %d are ready.", self.task_id)
232            pylogger.info(
233                "Test dataset for task %d size: %d",
234                self.task_id,
235                len(self.dataset_test[self.task_id]),
236            )
237
238    def setup_task_id(self, task_id: int) -> None:
239        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
240
241        **Args:**
242        - **task_id** (`int`): the target task ID.
243        """
244
245        self.task_id = task_id
246
247        self.root_t = self.root[task_id]
248        self.batch_size_t = self.batch_size[task_id]
249        self.num_workers_t = self.num_workers[task_id]
250        self.custom_transforms_t = self.custom_transforms[task_id]
251        self.repeat_channels_t = self.repeat_channels[task_id]
252        self.to_tensor_t = self.to_tensor[task_id]
253        self.resize_t = self.resize[task_id]
254
255        self.processed_task_ids.append(task_id)
256
257    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
258        r"""Set up tasks for continual learning main evaluation.
259
260        **Args:**
261        - **eval_tasks** (`list[int]`): the list of task IDs to evaluate.
262        """
263        print(eval_tasks, "eval_tasks")
264        for task_id in eval_tasks:
265            self.setup_task_id(task_id=task_id)
266            self.setup(stage="test")
267
268    def set_cl_paradigm(self, cl_paradigm: str) -> None:
269        r"""Set `cl_paradigm` to `self.cl_paradigm`. It is used to define the CL class map.
270
271        **Args:**
272        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
273        """
274        self.cl_paradigm = cl_paradigm
275
276    def train_and_val_transforms(self) -> transforms.Compose:
277        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.
278
279        **Returns:**
280        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
281        """
282        repeat_channels_transform = (
283            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
284            if self.repeat_channels_t is not None
285            else None
286        )
287        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
288        resize_transform = (
289            transforms.Resize(self.resize_t) if self.resize_t is not None else None
290        )
291        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
292
293        return transforms.Compose(
294            list(
295                filter(
296                    None,
297                    [
298                        repeat_channels_transform,
299                        to_tensor_transform,
300                        resize_transform,
301                        self.custom_transforms_t,
302                        normalization_transform,
303                    ],
304                )
305            )
306        )  # the order of transforms matters
307
308    def test_transforms(self) -> transforms.Compose:
309        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. It is used in subclasses when constructing the dataset.
310
311        **Returns:**
312        - **test_transforms** (`transforms.Compose`): the composed test transforms.
313        """
314        repeat_channels_transform = (
315            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
316            if self.repeat_channels_t is not None
317            else None
318        )
319        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
320        resize_transform = (
321            transforms.Resize(self.resize_t) if self.resize_t is not None else None
322        )
323        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
324
325        return transforms.Compose(
326            list(
327                filter(
328                    None,
329                    [
330                        repeat_channels_transform,
331                        to_tensor_transform,
332                        resize_transform,
333                        normalization_transform,
334                    ],
335                )
336            )
337        )  # the order of transforms matters. No custom transforms for test
338
339    def target_transform(self) -> ClassMapping:
340        r"""Target transform to map the original class labels to CL class labels according to `self.cl_paradigm`. It can be used in subclasses when constructing the dataset.
341
342        **Returns:**
343        - **target_transform** (`Callable`): the target transform function.
344        """
345
346        cl_class_map = self.get_cl_class_map(task_id=self.task_id)
347
348        target_transform = ClassMapping(class_map=cl_class_map)
349
350        return target_transform
351
352    @abstractmethod
353    def train_and_val_dataset(self) -> tuple[Any, Any]:
354        r"""Get the training and validation datasets of the current task `self.task_id`. It must be implemented by subclasses.
355
356        **Returns:**
357        - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation datasets of the current task `self.task_id`.
358        """
359
360    @abstractmethod
361    def test_dataset(self) -> Any:
362        r"""Get the test dataset of the current task `self.task_id`. It must be implemented by subclasses.
363
364        **Returns:**
365        - **test_dataset** (`Any`): the test dataset of the current task `self.task_id`.
366        """
367
368    def train_dataloader(self) -> DataLoader:
369        r"""DataLoader generator for the train stage of the current task `self.task_id`. It is automatically called before training the task.
370
371        **Returns:**
372        - **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
373        """
374
375        pylogger.debug("Construct train dataloader for task %d...", self.task_id)
376
377        return DataLoader(
378            dataset=self.dataset_train_t,
379            batch_size=self.batch_size_t,
380            shuffle=True,  # shuffle train batch to prevent overfitting
381            num_workers=self.num_workers_t,
382            drop_last=True,  # to avoid batchnorm error (when batch_size is 1)
383        )
384
385    def val_dataloader(self) -> DataLoader:
386        r"""DataLoader generator for the validation stage of the current task `self.task_id`. It is automatically called before the task's validation.
387
388        **Returns:**
389        - **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
390        """
391
392        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)
393
394        return DataLoader(
395            dataset=self.dataset_val_t,
396            batch_size=self.batch_size_t,
397            shuffle=False,  # don't have to shuffle val or test batch
398            num_workers=self.num_workers_t,
399        )
400
401    def test_dataloader(self) -> dict[int, DataLoader]:
402        r"""DataLoader generator for the test stage of the current task `self.task_id`. It is automatically called before testing the task.
403
404        **Returns:**
405        - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
406        """
407
408        pylogger.debug("Construct test dataloader for task %d...", self.task_id)
409
410        return {
411            task_id: DataLoader(
412                dataset=dataset_test_t,
413                batch_size=self.batch_size_t,
414                shuffle=False,  # don't have to shuffle val or test batch
415                num_workers=self.num_workers_t,
416            )
417            for task_id, dataset_test_t in self.dataset_test.items()
418        }
419
420    def __len__(self) -> int:
421        r"""Get the number of tasks in the dataset.
422
423        **Returns:**
424        - **num_tasks** (`int`): the number of tasks in the dataset.
425        """
426        return self.num_tasks

The base class of continual learning datasets.

CLDataset( root: str | dict[int, str], num_tasks: int, batch_size: int | dict[int, int] = 1, num_workers: int | dict[int, int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)
 38    def __init__(
 39        self,
 40        root: str | dict[int, str],
 41        num_tasks: int,
 42        batch_size: int | dict[int, int] = 1,
 43        num_workers: int | dict[int, int] = 0,
 44        custom_transforms: (
 45            Callable
 46            | transforms.Compose
 47            | None
 48            | dict[int, Callable | transforms.Compose | None]
 49        ) = None,
 50        repeat_channels: int | None | dict[int, int | None] = None,
 51        to_tensor: bool | dict[int, bool] = True,
 52        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
 53    ) -> None:
 54        r"""
 55        **Args:**
 56        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
 57        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
 58        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
 59        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
 60        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
 61        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
 62        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
 63        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
 64        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
 65        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
 66        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
 67        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
 68        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
 69        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
 70        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
 71        """
 72        super().__init__()
 73
 74        self.root: dict[int, str] = (
 75            OmegaConf.to_container(root)
 76            if isinstance(root, DictConfig)
 77            else {t: root for t in range(1, num_tasks + 1)}
 78        )
 79        r"""The dict of root directories of the original data files for each task."""
 80        self.num_tasks: int = num_tasks
 81        r"""The maximum number of tasks supported by the dataset."""
 82        self.cl_paradigm: str
 83        r"""The continual learning paradigm."""
 84        self.batch_size: dict[int, int] = (
 85            OmegaConf.to_container(batch_size)
 86            if isinstance(batch_size, DictConfig)
 87            else {t: batch_size for t in range(1, num_tasks + 1)}
 88        )
 89        r"""The dict of batch sizes for each task."""
 90        self.num_workers: dict[int, int] = (
 91            OmegaConf.to_container(num_workers)
 92            if isinstance(num_workers, DictConfig)
 93            else {t: num_workers for t in range(1, num_tasks + 1)}
 94        )
 95        r"""The dict of numbers of workers for each task."""
 96        self.custom_transforms: dict[int, Callable | transforms.Compose | None] = (
 97            OmegaConf.to_container(custom_transforms)
 98            if isinstance(custom_transforms, DictConfig)
 99            else {t: custom_transforms for t in range(1, num_tasks + 1)}
100        )
101        r"""The dict of custom transforms for each task."""
102        self.repeat_channels: dict[int, int | None] = (
103            OmegaConf.to_container(repeat_channels)
104            if isinstance(repeat_channels, DictConfig)
105            else {t: repeat_channels for t in range(1, num_tasks + 1)}
106        )
107        r"""The dict of number of channels to repeat for each task."""
108        self.to_tensor: dict[int, bool] = (
109            OmegaConf.to_container(to_tensor)
110            if isinstance(to_tensor, DictConfig)
111            else {t: to_tensor for t in range(1, num_tasks + 1)}
112        )
113        r"""The dict of to_tensor flag for each task. """
114        self.resize: dict[int, tuple[int, int] | None] = (
115            {t: tuple(rs) if rs else None for t, rs in resize.items()}
116            if isinstance(resize, DictConfig)
117            else {
118                t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1)
119            }
120        )
121        r"""The dict of sizes to resize to for each task."""
122
123        # task-specific attributes
124        self.root_t: str
125        r"""The root directory of the original data files for the current task `self.task_id`."""
126        self.batch_size_t: int
127        r"""The batch size for the current task `self.task_id`."""
128        self.num_workers_t: int
129        r"""The number of workers for the current task `self.task_id`."""
130        self.custom_transforms_t: Callable | transforms.Compose | None
131        r"""The custom transforms for the current task `self.task_id`."""
132        self.repeat_channels_t: int | None
133        r"""The number of channels to repeat for the current task `self.task_id`."""
134        self.to_tensor_t: bool
135        r"""The to_tensor flag for the current task `self.task_id`."""
136        self.resize_t: tuple[int, int] | None
137        r"""The size to resize for the current task `self.task_id`."""
138        self.mean_t: float
139        r"""The mean values for normalization for the current task `self.task_id`."""
140        self.std_t: float
141        r"""The standard deviation values for normalization for the current task `self.task_id`."""
142
143        # dataset containers
144        self.dataset_train_t: Any
145        r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
146        self.dataset_val_t: Any
147        r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
148        self.dataset_test: dict[int, Any] = {}
149        r"""The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects."""
150
151        # task ID control
152        self.task_id: int
153        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
154        self.processed_task_ids: list[int] = []
155        r"""Task IDs that have been processed."""
156
157        CLDataset.sanity_check(self)

Args:

  • root (str | dict[int, str]): the root directory where the original data files for constructing the CL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to num_tasks.
  • batch_size (int | dict[int, int]): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an int, it is the same batch size for all tasks.
  • num_workers (int | dict[int, int]): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an int, it is the same number of workers for all tasks.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
root: dict[int, str]

The dict of root directories of the original data files for each task.

num_tasks: int

The maximum number of tasks supported by the dataset.

cl_paradigm: str

The continual learning paradigm.

batch_size: dict[int, int]

The dict of batch sizes for each task.

num_workers: dict[int, int]

The dict of numbers of workers for each task.

custom_transforms: dict[int, typing.Union[typing.Callable, torchvision.transforms.transforms.Compose, NoneType]]

The dict of custom transforms for each task.

repeat_channels: dict[int, int | None]

The dict of number of channels to repeat for each task.

to_tensor: dict[int, bool]

The dict of to_tensor flag for each task.

resize: dict[int, tuple[int, int] | None]

The dict of sizes to resize to for each task.

root_t: str

The root directory of the original data files for the current task self.task_id.

batch_size_t: int

The batch size for the current task self.task_id.

num_workers_t: int

The number of workers for the current task self.task_id.

custom_transforms_t: Union[Callable, torchvision.transforms.transforms.Compose, NoneType]

The custom transforms for the current task self.task_id.

repeat_channels_t: int | None

The number of channels to repeat for the current task self.task_id.

to_tensor_t: bool

The to_tensor flag for the current task self.task_id.

resize_t: tuple[int, int] | None

The size to resize for the current task self.task_id.

mean_t: float

The mean values for normalization for the current task self.task_id.

std_t: float

The standard deviation values for normalization for the current task self.task_id.

dataset_train_t: Any

The training dataset object. Can be a PyTorch Dataset object or any other dataset object.

dataset_val_t: Any

The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.

dataset_test: dict[int, typing.Any]

The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset.

processed_task_ids: list[int]

Task IDs that have been processed.

def sanity_check(self) -> None:
159    def sanity_check(self) -> None:
160        r"""Sanity check."""
161
162        # check if each task has been provided with necessary arguments
163        for attr in [
164            "root",
165            "batch_size",
166            "num_workers",
167            "custom_transforms",
168            "repeat_channels",
169            "to_tensor",
170            "resize",
171        ]:
172            value = getattr(self, attr)
173            expected_keys = set(range(1, self.num_tasks + 1))
174            if set(value.keys()) != expected_keys:
175                raise ValueError(
176                    f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
177                )

Sanity check.

@abstractmethod
def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
    @abstractmethod
    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of each task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of each task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
        """

Get the mapping of classes of task task_id to fit continual learning settings self.cl_paradigm. It must be implemented by subclasses.

Args:

  • task_id (int): the task ID to query the CL class map.

Returns:

  • cl_class_map (dict[str | int, int]): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of each task should be continuous integers from 0 to the number of classes.
    • If self.cl_paradigm is 'CIL', the mapped class labels of each task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
@abstractmethod
def prepare_data(self) -> None:
    @abstractmethod
    def prepare_data(self) -> None:
        r"""Download and prepare the raw data files.

        It must be implemented by subclasses, as required by `LightningDataModule`.
        This method is called at the beginning of each task, before `setup()`.
        """

Use this to download and prepare data. It must be implemented by subclasses, as required by LightningDataModule. This method is called at the beginning of each task.

def setup(self, stage: str) -> None:
196    def setup(self, stage: str) -> None:
197        r"""Set up the dataset for different stages. This method is called at the beginning of each task.
198
199        **Args:**
200        - **stage** (`str`): the stage of the experiment; one of:
201            - 'fit': training and validation datasets of the current task `self.task_id` are assigned to `self.dataset_train_t` and `self.dataset_val_t`.
202            - 'test': a dict of test datasets of all seen tasks should be assigned to `self.dataset_test`.
203        """
204        if stage == "fit":
205            # these two stages must be done together because a sanity check for validation is conducted before training
206            pylogger.debug(
207                "Construct train and validation dataset for task %d...", self.task_id
208            )
209
210            self.dataset_train_t, self.dataset_val_t = self.train_and_val_dataset()
211
212            pylogger.info(
213                "Train and validation dataset for task %d are ready.", self.task_id
214            )
215            pylogger.info(
216                "Train dataset for task %d size: %d",
217                self.task_id,
218                len(self.dataset_train_t),
219            )
220            pylogger.info(
221                "Validation dataset for task %d size: %d",
222                self.task_id,
223                len(self.dataset_val_t),
224            )
225
226        elif stage == "test":
227            pylogger.debug("Construct test dataset for task %d...", self.task_id)
228
229            self.dataset_test[self.task_id] = self.test_dataset()
230
231            pylogger.info("Test dataset for task %d are ready.", self.task_id)
232            pylogger.info(
233                "Test dataset for task %d size: %d",
234                self.task_id,
235                len(self.dataset_test[self.task_id]),
236            )

Set up the dataset for different stages. This method is called at the beginning of each task.

Args:

  • stage (str): the stage of the experiment; one of:
    • 'fit': training and validation datasets of the current task self.task_id are assigned to self.dataset_train_t and self.dataset_val_t.
    • 'test': a dict of test datasets of all seen tasks should be assigned to self.dataset_test.
def setup_task_id(self, task_id: int) -> None:
238    def setup_task_id(self, task_id: int) -> None:
239        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
240
241        **Args:**
242        - **task_id** (`int`): the target task ID.
243        """
244
245        self.task_id = task_id
246
247        self.root_t = self.root[task_id]
248        self.batch_size_t = self.batch_size[task_id]
249        self.num_workers_t = self.num_workers[task_id]
250        self.custom_transforms_t = self.custom_transforms[task_id]
251        self.repeat_channels_t = self.repeat_channels[task_id]
252        self.to_tensor_t = self.to_tensor[task_id]
253        self.resize_t = self.resize[task_id]
254
255        self.processed_task_ids.append(task_id)

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.
def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
257    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
258        r"""Set up tasks for continual learning main evaluation.
259
260        **Args:**
261        - **eval_tasks** (`list[int]`): the list of task IDs to evaluate.
262        """
263        print(eval_tasks, "eval_tasks")
264        for task_id in eval_tasks:
265            self.setup_task_id(task_id=task_id)
266            self.setup(stage="test")

Set up tasks for continual learning main evaluation.

Args:

  • eval_tasks (list[int]): the list of task IDs to evaluate.
def set_cl_paradigm(self, cl_paradigm: str) -> None:
268    def set_cl_paradigm(self, cl_paradigm: str) -> None:
269        r"""Set `cl_paradigm` to `self.cl_paradigm`. It is used to define the CL class map.
270
271        **Args:**
272        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
273        """
274        self.cl_paradigm = cl_paradigm

Set cl_paradigm to self.cl_paradigm. It is used to define the CL class map.

Args:

  • cl_paradigm (str): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
def train_and_val_transforms(self) -> torchvision.transforms.transforms.Compose:
276    def train_and_val_transforms(self) -> transforms.Compose:
277        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.
278
279        **Returns:**
280        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
281        """
282        repeat_channels_transform = (
283            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
284            if self.repeat_channels_t is not None
285            else None
286        )
287        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
288        resize_transform = (
289            transforms.Resize(self.resize_t) if self.resize_t is not None else None
290        )
291        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
292
293        return transforms.Compose(
294            list(
295                filter(
296                    None,
297                    [
298                        repeat_channels_transform,
299                        to_tensor_transform,
300                        resize_transform,
301                        self.custom_transforms_t,
302                        normalization_transform,
303                    ],
304                )
305            )
306        )  # the order of transforms matters

Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and ToTensor(). It can be used in subclasses when constructing the dataset.

Returns:

  • train_and_val_transforms (transforms.Compose): the composed train/val transforms.
def test_transforms(self) -> torchvision.transforms.transforms.Compose:
308    def test_transforms(self) -> transforms.Compose:
309        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. It is used in subclasses when constructing the dataset.
310
311        **Returns:**
312        - **test_transforms** (`transforms.Compose`): the composed test transforms.
313        """
314        repeat_channels_transform = (
315            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
316            if self.repeat_channels_t is not None
317            else None
318        )
319        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
320        resize_transform = (
321            transforms.Resize(self.resize_t) if self.resize_t is not None else None
322        )
323        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
324
325        return transforms.Compose(
326            list(
327                filter(
328                    None,
329                    [
330                        repeat_channels_transform,
331                        to_tensor_transform,
332                        resize_transform,
333                        normalization_transform,
334                    ],
335                )
336            )
337        )  # the order of transforms matters. No custom transforms for test

Transforms for the test dataset. Only basic transforms like normalization and ToTensor() are included. It is used in subclasses when constructing the dataset.

Returns:

  • test_transforms (transforms.Compose): the composed test transforms.
def target_transform(self) -> clarena.utils.transforms.ClassMapping:
339    def target_transform(self) -> ClassMapping:
340        r"""Target transform to map the original class labels to CL class labels according to `self.cl_paradigm`. It can be used in subclasses when constructing the dataset.
341
342        **Returns:**
343        - **target_transform** (`Callable`): the target transform function.
344        """
345
346        cl_class_map = self.get_cl_class_map(task_id=self.task_id)
347
348        target_transform = ClassMapping(class_map=cl_class_map)
349
350        return target_transform

Target transform to map the original class labels to CL class labels according to self.cl_paradigm. It can be used in subclasses when constructing the dataset.

Returns:

  • target_transform (Callable): the target transform function.
@abstractmethod
def train_and_val_dataset(self) -> tuple[typing.Any, typing.Any]:
    @abstractmethod
    def train_and_val_dataset(self) -> tuple[Any, Any]:
        r"""Get the training and validation datasets of the current task `self.task_id`. It must be implemented by subclasses.

        Called by `setup(stage="fit")` after `setup_task_id()` has selected the task.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation datasets of the current task `self.task_id`.
        """

Get the training and validation datasets of the current task self.task_id. It must be implemented by subclasses.

Returns:

  • train_and_val_dataset (tuple[Any, Any]): the train and validation datasets of the current task self.task_id.
@abstractmethod
def test_dataset(self) -> Any:
    @abstractmethod
    def test_dataset(self) -> Any:
        r"""Get the test dataset of the current task `self.task_id`. It must be implemented by subclasses.

        Called by `setup(stage="test")` after `setup_task_id()` has selected the task.

        **Returns:**
        - **test_dataset** (`Any`): the test dataset of the current task `self.task_id`.
        """

Get the test dataset of the current task self.task_id. It must be implemented by subclasses.

Returns:

  • test_dataset (Any): the test dataset of the current task self.task_id.
def train_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
368    def train_dataloader(self) -> DataLoader:
369        r"""DataLoader generator for the train stage of the current task `self.task_id`. It is automatically called before training the task.
370
371        **Returns:**
372        - **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
373        """
374
375        pylogger.debug("Construct train dataloader for task %d...", self.task_id)
376
377        return DataLoader(
378            dataset=self.dataset_train_t,
379            batch_size=self.batch_size_t,
380            shuffle=True,  # shuffle train batch to prevent overfitting
381            num_workers=self.num_workers_t,
382            drop_last=True,  # to avoid batchnorm error (when batch_size is 1)
383        )

DataLoader generator for the train stage of the current task self.task_id. It is automatically called before training the task.

Returns:

  • train_dataloader (DataLoader): the train DataLoader of task self.task_id.
def val_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
385    def val_dataloader(self) -> DataLoader:
386        r"""DataLoader generator for the validation stage of the current task `self.task_id`. It is automatically called before the task's validation.
387
388        **Returns:**
389        - **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
390        """
391
392        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)
393
394        return DataLoader(
395            dataset=self.dataset_val_t,
396            batch_size=self.batch_size_t,
397            shuffle=False,  # don't have to shuffle val or test batch
398            num_workers=self.num_workers_t,
399        )

DataLoader generator for the validation stage of the current task self.task_id. It is automatically called before the task's validation.

Returns:

  • val_dataloader (DataLoader): the validation DataLoader of task self.task_id.
def test_dataloader(self) -> dict[int, torch.utils.data.dataloader.DataLoader]:
401    def test_dataloader(self) -> dict[int, DataLoader]:
402        r"""DataLoader generator for the test stage of the current task `self.task_id`. It is automatically called before testing the task.
403
404        **Returns:**
405        - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
406        """
407
408        pylogger.debug("Construct test dataloader for task %d...", self.task_id)
409
410        return {
411            task_id: DataLoader(
412                dataset=dataset_test_t,
413                batch_size=self.batch_size_t,
414                shuffle=False,  # don't have to shuffle val or test batch
415                num_workers=self.num_workers_t,
416            )
417            for task_id, dataset_test_t in self.dataset_test.items()
418        }

DataLoader generator for the test stage of the current task self.task_id. It is automatically called before testing the task.

Returns:

  • test_dataloader (dict[int, DataLoader]): the test DataLoader dict of self.task_id and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
class CLPermutedDataset(clarena.cl_datasets.CLDataset):
429class CLPermutedDataset(CLDataset):
430    r"""The base class of continual learning datasets constructed as permutations of an original dataset."""
431
432    original_dataset_python_class: type[Dataset]
433    r"""The original dataset class. **It must be provided in subclasses.** """
434
    def __init__(
        self,
        root: str,
        num_tasks: int,
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
        permutation_mode: str = "first_channel_only",
        permutation_seeds: dict[int, int] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original dataset lives.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        - **permutation_mode** (`str`): the mode of permutation; one of:
            1. 'all': permute all pixels.
            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
            3. 'first_channel_only': permute only the first channel.
        - **permutation_seeds** (`dict[int, int]` | `None`): the dict of seeds for permutation operations used to construct each task. Keys are task IDs and the values are permutation seeds for each task. Default is `None`, which creates a dict of seeds from 0 to `num_tasks`-1.
        """
        super().__init__(
            root=root,
            num_tasks=num_tasks,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        # look up the constants (mean/std/num classes/...) of the original
        # dataset; requires subclasses to set `original_dataset_python_class`
        self.original_dataset_constants: type[DatasetConstants] = (
            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
        )
        r"""The original dataset constants class."""

        self.permutation_mode: str = permutation_mode
        r"""The mode of permutation."""
        # NOTE(review): this is a truthiness check, so an explicitly provided
        # empty dict also falls back to the default seeds — confirm intended
        self.permutation_seeds: dict[int, int] = (
            permutation_seeds
            if permutation_seeds
            else {t: t - 1 for t in range(1, num_tasks + 1)}
        )
        r"""The dict of permutation seeds for each task."""

        # task-specific attributes, assigned later by `setup_task_id()`
        self.permutation_seed_t: int
        r"""The permutation seed for the current task `self.task_id`."""
        self.permute_transform_t: Permute
        r"""The permutation transform for the current task `self.task_id`."""

        # call explicitly on this class so subclass overrides don't run here
        CLPermutedDataset.sanity_check(self)
506
507    def sanity_check(self) -> None:
508        r"""Sanity check."""
509
510        # check the permutation mode
511        if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]:
512            raise ValueError(
513                "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'."
514            )
515
516        # check the permutation seeds
517        expected_keys = set(range(1, self.num_tasks + 1))
518        if set(self.permutation_seeds.keys()) != expected_keys:
519            raise ValueError(
520                f"{self.permutation_seeds} dict keys must be consecutive integers from 1 to num_tasks."
521            )
522
523    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
524        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
525
526        **Args:**
527        - **task_id** (`int`): the task ID to query the CL class map.
528
529        **Returns:**
530        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
531            - If `self.cl_paradigm` is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
532            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
533        """
534
535        num_classes_t = (
536            self.original_dataset_constants.NUM_CLASSES
537        )  # the same with the original dataset
538        class_map_t = (
539            self.original_dataset_constants.CLASS_MAP
540        )  # the same with the original dataset
541
542        if self.cl_paradigm == "TIL" or "DIL":
543            return {class_map_t[i]: i for i in range(num_classes_t)}
544        if self.cl_paradigm == "CIL":
545            return {
546                class_map_t[i]: i + (task_id - 1) * num_classes_t
547                for i in range(num_classes_t)
548            }
549
    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.

        **Raises:**
        - **AttributeError**: if the original dataset constants have no fixed
          `IMG_SIZE` and no `resize` was specified, since the permutation
          needs a fixed image size.
        """

        # base-class setup assigns task_id and all per-task *_t attributes
        CLDataset.setup_task_id(self, task_id)

        self.mean_t = (
            self.original_dataset_constants.MEAN
        )  # the same with the original dataset
        self.std_t = (
            self.original_dataset_constants.STD
        )  # the same with the original dataset

        # channel count after the optional repeat-channels transform
        num_channels = (
            self.original_dataset_constants.NUM_CHANNELS
            if self.repeat_channels_t is None
            else self.repeat_channels_t
        )

        # image size after the optional resize transform; an explicit resize
        # takes priority over the original dataset's IMG_SIZE
        if (
            hasattr(self.original_dataset_constants, "IMG_SIZE")
            or self.resize_t is not None
        ):
            img_size = (
                self.original_dataset_constants.IMG_SIZE
                if self.resize_t is None
                else torch.Size(self.resize_t)
            )
        else:
            raise AttributeError(
                "The original dataset has different image sizes. Please resize the images to a fixed size by specifying hyperparameter: resize."
            )

        # set up the permutation transform
        self.permutation_seed_t = self.permutation_seeds[task_id]
        self.permute_transform_t = Permute(
            num_channels=num_channels,
            img_size=img_size,
            mode=self.permutation_mode,
            seed=self.permutation_seed_t,
        )
594
595    def train_and_val_transforms(self) -> transforms.Compose:
596        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform also applies.
597
598        **Returns:**
599        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
600        """
601
602        repeat_channels_transform = (
603            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
604            if self.repeat_channels_t is not None
605            else None
606        )
607        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
608        resize_transform = (
609            transforms.Resize(self.resize_t) if self.resize_t is not None else None
610        )
611        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
612
613        return transforms.Compose(
614            list(
615                filter(
616                    None,
617                    [
618                        repeat_channels_transform,
619                        to_tensor_transform,
620                        resize_transform,
621                        self.permute_transform_t,  # permutation is included here
622                        self.custom_transforms_t,
623                        normalization_transform,
624                    ],
625                )
626            )
627        )  # the order of transforms matters
628
629    def test_transforms(self) -> transforms.Compose:
630        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform also applies.
631
632        **Returns:**
633        - **test_transforms** (`transforms.Compose`): the composed test transforms.
634        """
635
636        repeat_channels_transform = (
637            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
638            if self.repeat_channels_t is not None
639            else None
640        )
641        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
642        resize_transform = (
643            transforms.Resize(self.resize_t) if self.resize_t is not None else None
644        )
645        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
646
647        return transforms.Compose(
648            list(
649                filter(
650                    None,
651                    [
652                        repeat_channels_transform,
653                        to_tensor_transform,
654                        resize_transform,
655                        self.permute_transform_t,  # permutation is included here
656                        normalization_transform,
657                    ],
658                )
659            )
660        )  # the order of transforms matters. No custom transforms for test

The base class of continual learning datasets constructed as permutations of an original dataset.

CLPermutedDataset( root: str, num_tasks: int, batch_size: int | dict[int, int] = 1, num_workers: int | dict[int, int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, permutation_mode: str = 'first_channel_only', permutation_seeds: dict[int, int] | None = None)
435    def __init__(
436        self,
437        root: str,
438        num_tasks: int,
439        batch_size: int | dict[int, int] = 1,
440        num_workers: int | dict[int, int] = 0,
441        custom_transforms: (
442            Callable
443            | transforms.Compose
444            | None
445            | dict[int, Callable | transforms.Compose | None]
446        ) = None,
447        repeat_channels: int | None | dict[int, int | None] = None,
448        to_tensor: bool | dict[int, bool] = True,
449        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
450        permutation_mode: str = "first_channel_only",
451        permutation_seeds: dict[int, int] | None = None,
452    ) -> None:
453        r"""
454        **Args:**
455        - **root** (`str`): the root directory where the original dataset live.
456        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
457        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
458        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
459        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
460        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
461        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
462        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
463        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
464        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
465        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
466        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
467        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
468        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
469        - **permutation_mode** (`str`): the mode of permutation; one of:
470            1. 'all': permute all pixels.
471            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
472            3. 'first_channel_only': permute only the first channel.
473        - **permutation_seeds** (`dict[int, int]` | `None`): the dict of seeds for permutation operations used to construct each task. Keys are task IDs and the values are permutation seeds for each task. Default is `None`, which creates a dict of seeds from 0 to `num_tasks`-1.
474        """
475        super().__init__(
476            root=root,
477            num_tasks=num_tasks,
478            batch_size=batch_size,
479            num_workers=num_workers,
480            custom_transforms=custom_transforms,
481            repeat_channels=repeat_channels,
482            to_tensor=to_tensor,
483            resize=resize,
484        )
485
486        self.original_dataset_constants: type[DatasetConstants] = (
487            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
488        )
489        r"""The original dataset constants class."""
490
491        self.permutation_mode: str = permutation_mode
492        r"""The mode of permutation."""
493        self.permutation_seeds: dict[int, int] = (
494            permutation_seeds
495            if permutation_seeds
496            else {t: t - 1 for t in range(1, num_tasks + 1)}
497        )
498        r"""The dict of permutation seeds for each task."""
499
500        self.permutation_seed_t: int
501        r"""The permutation seed for the current task `self.task_id`."""
502        self.permute_transform_t: Permute
503        r"""The permutation transform for the current task `self.task_id`."""
504
505        CLPermutedDataset.sanity_check(self)

Args:

  • root (str): the root directory where the original dataset lives.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to num_tasks.
  • batch_size (int | dict[int, int]): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an int, it is the same batch size for all tasks.
  • num_workers (int | dict[int, int]): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an int, it is the same number of workers for all tasks.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
  • permutation_mode (str): the mode of permutation; one of:
    1. 'all': permute all pixels.
    2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
    3. 'first_channel_only': permute only the first channel.
  • permutation_seeds (dict[int, int] | None): the dict of seeds for permutation operations used to construct each task. Keys are task IDs and the values are permutation seeds for each task. Default is None, which creates a dict of seeds from 0 to num_tasks-1.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset]

The original dataset class. It must be provided in subclasses.

original_dataset_constants: type[clarena.stl_datasets.raw.constants.DatasetConstants]

The original dataset constants class.

permutation_mode: str

The mode of permutation.

permutation_seeds: dict[int, int]

The dict of permutation seeds for each task.

permutation_seed_t: int

The permutation seed for the current task self.task_id.

permute_transform_t: clarena.utils.transforms.Permute

The permutation transform for the current task self.task_id.

def sanity_check(self) -> None:
507    def sanity_check(self) -> None:
508        r"""Sanity check."""
509
510        # check the permutation mode
511        if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]:
512            raise ValueError(
513                "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'."
514            )
515
516        # check the permutation seeds
517        expected_keys = set(range(1, self.num_tasks + 1))
518        if set(self.permutation_seeds.keys()) != expected_keys:
519            raise ValueError(
520                f"{self.permutation_seeds} dict keys must be consecutive integers from 1 to num_tasks."
521            )

Sanity check.

def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
523    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
524        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
525
526        **Args:**
527        - **task_id** (`int`): the task ID to query the CL class map.
528
529        **Returns:**
530        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
531            - If `self.cl_paradigm` is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
532            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
533        """
534
535        num_classes_t = (
536            self.original_dataset_constants.NUM_CLASSES
537        )  # the same with the original dataset
538        class_map_t = (
539            self.original_dataset_constants.CLASS_MAP
540        )  # the same with the original dataset
541
542        if self.cl_paradigm == "TIL" or "DIL":
543            return {class_map_t[i]: i for i in range(num_classes_t)}
544        if self.cl_paradigm == "CIL":
545            return {
546                class_map_t[i]: i + (task_id - 1) * num_classes_t
547                for i in range(num_classes_t)
548            }

Get the mapping of classes of task task_id to fit continual learning settings self.cl_paradigm.

Args:

  • task_id (int): the task ID to query the CL class map.

Returns:

  • cl_class_map (dict[str | int, int]): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
    • If self.cl_paradigm is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
def setup_task_id(self, task_id: int) -> None:
    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.

        **Raises:**
        - **AttributeError**: if the original dataset declares no fixed image size and no `resize` was requested.
        """

        # delegate common bookkeeping (task ID validation etc. -- defined in the
        # base class, outside this view) before wiring the per-task transform
        CLDataset.setup_task_id(self, task_id)

        # normalization statistics are reused from the original dataset
        self.mean_t = (
            self.original_dataset_constants.MEAN
        )  # the same with the original dataset
        self.std_t = (
            self.original_dataset_constants.STD
        )  # the same with the original dataset

        # repeat_channels_t (when set) overrides the original channel count
        num_channels = (
            self.original_dataset_constants.NUM_CHANNELS
            if self.repeat_channels_t is None
            else self.repeat_channels_t
        )

        # a fixed image size is required to define a pixel permutation: either
        # the original dataset declares IMG_SIZE, or the user requested a resize
        if (
            hasattr(self.original_dataset_constants, "IMG_SIZE")
            or self.resize_t is not None
        ):
            img_size = (
                self.original_dataset_constants.IMG_SIZE
                if self.resize_t is None
                else torch.Size(self.resize_t)
            )
        else:
            raise AttributeError(
                "The original dataset has different image sizes. Please resize the images to a fixed size by specifying hyperparameter: resize."
            )

        # set up the permutation transform (seeded per task, so reproducible)
        self.permutation_seed_t = self.permutation_seeds[task_id]
        self.permute_transform_t = Permute(
            num_channels=num_channels,
            img_size=img_size,
            mode=self.permutation_mode,
            seed=self.permutation_seed_t,
        )

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.
def train_and_val_transforms(self) -> torchvision.transforms.transforms.Compose:
595    def train_and_val_transforms(self) -> transforms.Compose:
596        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform also applies.
597
598        **Returns:**
599        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
600        """
601
602        repeat_channels_transform = (
603            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
604            if self.repeat_channels_t is not None
605            else None
606        )
607        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
608        resize_transform = (
609            transforms.Resize(self.resize_t) if self.resize_t is not None else None
610        )
611        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
612
613        return transforms.Compose(
614            list(
615                filter(
616                    None,
617                    [
618                        repeat_channels_transform,
619                        to_tensor_transform,
620                        resize_transform,
621                        self.permute_transform_t,  # permutation is included here
622                        self.custom_transforms_t,
623                        normalization_transform,
624                    ],
625                )
626            )
627        )  # the order of transforms matters

Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and ToTensor(). In permuted CL datasets, a permute transform also applies.

Returns:

  • train_and_val_transforms (transforms.Compose): the composed train/val transforms.
def test_transforms(self) -> torchvision.transforms.transforms.Compose:
629    def test_transforms(self) -> transforms.Compose:
630        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform also applies.
631
632        **Returns:**
633        - **test_transforms** (`transforms.Compose`): the composed test transforms.
634        """
635
636        repeat_channels_transform = (
637            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
638            if self.repeat_channels_t is not None
639            else None
640        )
641        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
642        resize_transform = (
643            transforms.Resize(self.resize_t) if self.resize_t is not None else None
644        )
645        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
646
647        return transforms.Compose(
648            list(
649                filter(
650                    None,
651                    [
652                        repeat_channels_transform,
653                        to_tensor_transform,
654                        resize_transform,
655                        self.permute_transform_t,  # permutation is included here
656                        normalization_transform,
657                    ],
658                )
659            )
660        )  # the order of transforms matters. No custom transforms for test

Transforms for the test dataset. Only basic transforms like normalization and ToTensor() are included. In permuted CL datasets, a permute transform also applies.

Returns:

  • test_transforms (transforms.Compose): the composed test transforms.
class CLRotatedDataset(clarena.cl_datasets.CLDataset):
class CLRotatedDataset(CLDataset):
    r"""The base class of continual learning datasets constructed as rotations of an original dataset."""

    original_dataset_python_class: type[Dataset]
    r"""The original dataset class. **It must be provided in subclasses.** """

    def __init__(
        self,
        root: str,
        num_tasks: int,
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
        rotation_degrees: dict[int, float] | list[float] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original dataset lives.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and rotation are not included.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        - **rotation_degrees** (`dict[int, float]` | `list[float]` | `None`): the rotation degrees for each task. If it is a list, its length must match `num_tasks` and it is mapped to task IDs in order. If it is `None`, angles are evenly spaced in [0, 180) across tasks.

        **Raises:**
        - **ValueError**: if a list `rotation_degrees` does not have length `num_tasks`.
        - **TypeError**: if `rotation_degrees` is not a dict, list/tuple, or `None`.
        """
        super().__init__(
            root=root,
            num_tasks=num_tasks,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        # look up the constants declared for the original dataset class that
        # the subclass provides via `original_dataset_python_class`
        self.original_dataset_constants: type[DatasetConstants] = (
            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
        )
        r"""The original dataset constants class."""

        # configs coming from OmegaConf/Hydra arrive as DictConfig/ListConfig;
        # convert to plain Python containers before the type dispatch below
        if isinstance(rotation_degrees, (DictConfig, ListConfig)):
            rotation_degrees = OmegaConf.to_container(rotation_degrees)

        if rotation_degrees is None:
            # default: angles evenly spaced in [0, 180) across tasks
            step = 180.0 / num_tasks
            rotation_degrees = {
                t: (t - 1) * step for t in range(1, num_tasks + 1)
            }
        elif isinstance(rotation_degrees, (list, tuple)):
            if len(rotation_degrees) != num_tasks:
                raise ValueError(
                    "rotation_degrees length must match num_tasks."
                )
            # map list entries to task IDs 1..num_tasks in order
            rotation_degrees = {
                t: float(rotation_degrees[t - 1])
                for t in range(1, num_tasks + 1)
            }
        elif isinstance(rotation_degrees, dict):
            # normalize key/value types (e.g. string keys from config files)
            rotation_degrees = {
                int(task_id): float(angle)
                for task_id, angle in rotation_degrees.items()
            }
        else:
            raise TypeError(
                "rotation_degrees must be a dict, list, or None."
            )

        self.rotation_degrees: dict[int, float] = rotation_degrees
        r"""The rotation degrees for each task."""
        self.rotation_degree_t: float
        r"""The rotation degree for the current task `self.task_id`."""
        self.rotation_transform_t: transforms.RandomRotation
        r"""The rotation transform for the current task `self.task_id`."""

        # call explicitly on this class so subclasses that override
        # sanity_check do not change __init__ behavior
        CLRotatedDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # keys must cover exactly the task IDs 1..num_tasks
        expected_keys = set(range(1, self.num_tasks + 1))
        if set(self.rotation_degrees.keys()) != expected_keys:
            raise ValueError(
                "rotation_degrees dict keys must be consecutive integers from 1 to num_tasks."
            )

    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
            - If `self.cl_paradigm` is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.

        **Raises:**
        - **ValueError**: if `self.cl_paradigm` is not one of 'TIL', 'DIL', 'CIL'.
        """
        # class structure is the same as the original dataset; rotation does
        # not change the label set
        num_classes_t = self.original_dataset_constants.NUM_CLASSES
        class_map_t = self.original_dataset_constants.CLASS_MAP

        if self.cl_paradigm in ["TIL", "DIL"]:
            return {class_map_t[i]: i for i in range(num_classes_t)}
        if self.cl_paradigm == "CIL":
            # offset labels by the number of classes in all previous tasks
            return {
                class_map_t[i]: i + (task_id - 1) * num_classes_t
                for i in range(num_classes_t)
            }

        raise ValueError(f"Unsupported cl_paradigm: {self.cl_paradigm}")

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        CLDataset.setup_task_id(self, task_id)

        # normalization statistics are reused from the original dataset
        self.mean_t = self.original_dataset_constants.MEAN
        self.std_t = self.original_dataset_constants.STD

        # a degenerate (d, d) range makes RandomRotation rotate by exactly d
        # degrees every time, i.e. a deterministic per-task rotation
        self.rotation_degree_t = self.rotation_degrees[task_id]
        self.rotation_transform_t = transforms.RandomRotation(
            degrees=(self.rotation_degree_t, self.rotation_degree_t),
            fill=0,
        )

    def train_and_val_transforms(self) -> transforms.Compose:
        r"""Transforms for training and validation datasets. Rotation is applied before `ToTensor()` to keep PIL-based rotation.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """
        # each optional transform is None when disabled and dropped by the
        # filter(None, ...) below
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        self.rotation_transform_t,  # rotation before ToTensor()
                        to_tensor_transform,
                        resize_transform,
                        self.custom_transforms_t,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

    def test_transforms(self) -> transforms.Compose:
        r"""Transforms for the test dataset. Rotation is applied before `ToTensor()` to keep PIL-based rotation.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """
        # each optional transform is None when disabled and dropped by the
        # filter(None, ...) below
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        self.rotation_transform_t,  # rotation before ToTensor()
                        to_tensor_transform,
                        resize_transform,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters. No custom transforms for test

The base class of continual learning datasets constructed as rotations of an original dataset.

CLRotatedDataset( root: str, num_tasks: int, batch_size: int | dict[int, int] = 1, num_workers: int | dict[int, int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, rotation_degrees: dict[int, float] | list[float] | None = None)
671    def __init__(
672        self,
673        root: str,
674        num_tasks: int,
675        batch_size: int | dict[int, int] = 1,
676        num_workers: int | dict[int, int] = 0,
677        custom_transforms: (
678            Callable
679            | transforms.Compose
680            | None
681            | dict[int, Callable | transforms.Compose | None]
682        ) = None,
683        repeat_channels: int | None | dict[int, int | None] = None,
684        to_tensor: bool | dict[int, bool] = True,
685        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
686        rotation_degrees: dict[int, float] | list[float] | None = None,
687    ) -> None:
688        r"""
689        **Args:**
690        - **root** (`str`): the root directory where the original dataset live.
691        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
692        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
693        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
694        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
695        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
696        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and rotation are not included.
697        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
698        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
699        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
700        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
701        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
702        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
703        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
704        - **rotation_degrees** (`dict[int, float]` | `list[float]` | `None`): the rotation degrees for each task. If it is a list, its length must match `num_tasks` and it is mapped to task IDs in order. If it is `None`, angles are evenly spaced in [0, 180) across tasks.
705        """
706        super().__init__(
707            root=root,
708            num_tasks=num_tasks,
709            batch_size=batch_size,
710            num_workers=num_workers,
711            custom_transforms=custom_transforms,
712            repeat_channels=repeat_channels,
713            to_tensor=to_tensor,
714            resize=resize,
715        )
716
717        self.original_dataset_constants: type[DatasetConstants] = (
718            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
719        )
720        r"""The original dataset constants class."""
721
722        if isinstance(rotation_degrees, (DictConfig, ListConfig)):
723            rotation_degrees = OmegaConf.to_container(rotation_degrees)
724
725        if rotation_degrees is None:
726            step = 180.0 / num_tasks
727            rotation_degrees = {
728                t: (t - 1) * step for t in range(1, num_tasks + 1)
729            }
730        elif isinstance(rotation_degrees, (list, tuple)):
731            if len(rotation_degrees) != num_tasks:
732                raise ValueError(
733                    "rotation_degrees length must match num_tasks."
734                )
735            rotation_degrees = {
736                t: float(rotation_degrees[t - 1])
737                for t in range(1, num_tasks + 1)
738            }
739        elif isinstance(rotation_degrees, dict):
740            rotation_degrees = {
741                int(task_id): float(angle)
742                for task_id, angle in rotation_degrees.items()
743            }
744        else:
745            raise TypeError(
746                "rotation_degrees must be a dict, list, or None."
747            )
748
749        self.rotation_degrees: dict[int, float] = rotation_degrees
750        r"""The rotation degrees for each task."""
751        self.rotation_degree_t: float
752        r"""The rotation degree for the current task `self.task_id`."""
753        self.rotation_transform_t: transforms.RandomRotation
754        r"""The rotation transform for the current task `self.task_id`."""
755
756        CLRotatedDataset.sanity_check(self)

Args:

  • root (str): the root directory where the original dataset lives.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to num_tasks.
  • batch_size (int | dict[int, int]): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an int, it is the same batch size for all tasks.
  • num_workers (int | dict[int, int]): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an int, it is the same number of workers for all tasks.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, and rotation are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
  • rotation_degrees (dict[int, float] | list[float] | None): the rotation degrees for each task. If it is a list, its length must match num_tasks and it is mapped to task IDs in order. If it is None, angles are evenly spaced in [0, 180) across tasks.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset]

The original dataset class. It must be provided in subclasses.

original_dataset_constants: type[clarena.stl_datasets.raw.constants.DatasetConstants]

The original dataset constants class.

rotation_degrees: dict[int, float]

The rotation degrees for each task.

rotation_degree_t: float

The rotation degree for the current task self.task_id.

rotation_transform_t: torchvision.transforms.transforms.RandomRotation

The rotation transform for the current task self.task_id.

def sanity_check(self) -> None:
758    def sanity_check(self) -> None:
759        r"""Sanity check."""
760
761        expected_keys = set(range(1, self.num_tasks + 1))
762        if set(self.rotation_degrees.keys()) != expected_keys:
763            raise ValueError(
764                "rotation_degrees dict keys must be consecutive integers from 1 to num_tasks."
765            )

Sanity check.

def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
767    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
768        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
769
770        **Args:**
771        - **task_id** (`int`): the task ID to query the CL class map.
772
773        **Returns:**
774        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
775            - If `self.cl_paradigm` is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
776            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
777        """
778        num_classes_t = self.original_dataset_constants.NUM_CLASSES
779        class_map_t = self.original_dataset_constants.CLASS_MAP
780
781        if self.cl_paradigm in ["TIL", "DIL"]:
782            return {class_map_t[i]: i for i in range(num_classes_t)}
783        if self.cl_paradigm == "CIL":
784            return {
785                class_map_t[i]: i + (task_id - 1) * num_classes_t
786                for i in range(num_classes_t)
787            }
788
789        raise ValueError(f"Unsupported cl_paradigm: {self.cl_paradigm}")

Get the mapping of classes of task task_id to fit continual learning settings self.cl_paradigm.

Args:

  • task_id (int): the task ID to query the CL class map.

Returns:

  • cl_class_map (dict[str | int, int]): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
    • If self.cl_paradigm is 'TIL' or 'DIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
def setup_task_id(self, task_id: int) -> None:
791    def setup_task_id(self, task_id: int) -> None:
792        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
793
794        **Args:**
795        - **task_id** (`int`): the target task ID.
796        """
797
798        CLDataset.setup_task_id(self, task_id)
799
800        self.mean_t = self.original_dataset_constants.MEAN
801        self.std_t = self.original_dataset_constants.STD
802
803        self.rotation_degree_t = self.rotation_degrees[task_id]
804        self.rotation_transform_t = transforms.RandomRotation(
805            degrees=(self.rotation_degree_t, self.rotation_degree_t),
806            fill=0,
807        )

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.
def train_and_val_transforms(self) -> torchvision.transforms.transforms.Compose:
809    def train_and_val_transforms(self) -> transforms.Compose:
810        r"""Transforms for training and validation datasets. Rotation is applied before `ToTensor()` to keep PIL-based rotation.
811
812        **Returns:**
813        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
814        """
815        repeat_channels_transform = (
816            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
817            if self.repeat_channels_t is not None
818            else None
819        )
820        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
821        resize_transform = (
822            transforms.Resize(self.resize_t) if self.resize_t is not None else None
823        )
824        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
825
826        return transforms.Compose(
827            list(
828                filter(
829                    None,
830                    [
831                        repeat_channels_transform,
832                        self.rotation_transform_t,
833                        to_tensor_transform,
834                        resize_transform,
835                        self.custom_transforms_t,
836                        normalization_transform,
837                    ],
838                )
839            )
840        )  # the order of transforms matters

Transforms for training and validation datasets. Rotation is applied before ToTensor() to keep PIL-based rotation.

Returns:

  • train_and_val_transforms (transforms.Compose): the composed train/val transforms.
def test_transforms(self) -> torchvision.transforms.transforms.Compose:
842    def test_transforms(self) -> transforms.Compose:
843        r"""Transforms for the test dataset. Rotation is applied before `ToTensor()` to keep PIL-based rotation.
844
845        **Returns:**
846        - **test_transforms** (`transforms.Compose`): the composed test transforms.
847        """
848        repeat_channels_transform = (
849            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
850            if self.repeat_channels_t is not None
851            else None
852        )
853        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
854        resize_transform = (
855            transforms.Resize(self.resize_t) if self.resize_t is not None else None
856        )
857        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
858
859        return transforms.Compose(
860            list(
861                filter(
862                    None,
863                    [
864                        repeat_channels_transform,
865                        self.rotation_transform_t,
866                        to_tensor_transform,
867                        resize_transform,
868                        normalization_transform,
869                    ],
870                )
871            )
872        )  # the order of transforms matters. No custom transforms for test

Transforms for the test dataset. Rotation is applied before ToTensor() to keep PIL-based rotation.

Returns:

  • test_transforms (transforms.Compose): the composed test transforms.
class CLSplitDataset(clarena.cl_datasets.CLDataset):
 874class CLSplitDataset(CLDataset):
 875    r"""The base class of continual learning datasets constructed as splits of an original dataset."""
 876
 877    original_dataset_python_class: type[Dataset]
 878    r"""The original dataset class. **It must be provided in subclasses.** """
 879
 880    def __init__(
 881        self,
 882        root: str,
 883        class_split: dict[int, list[int]],
 884        batch_size: int | dict[int, int] = 1,
 885        num_workers: int | dict[int, int] = 0,
 886        custom_transforms: (
 887            Callable
 888            | transforms.Compose
 889            | None
 890            | dict[int, Callable | transforms.Compose | None]
 891        ) = None,
 892        repeat_channels: int | None | dict[int, int | None] = None,
 893        to_tensor: bool | dict[int, bool] = True,
 894        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
 895    ) -> None:
 896        r"""
 897        **Args:**
 898        - **root** (`str`): the root directory where the original dataset live.
 899        - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs ane the values are lists of class labels (integers starting from 0) to split for each task.
 900        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
 901        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
 902        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
 903        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
 904        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
 905        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
 906        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
 907        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
 908        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
 909        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
 910        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
 911        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
 912        """
 913        super().__init__(
 914            root=root,
 915            num_tasks=len(
 916                class_split
 917            ),  # num_tasks is not explicitly provided, but derived from the class_split length
 918            batch_size=batch_size,
 919            num_workers=num_workers,
 920            custom_transforms=custom_transforms,
 921            repeat_channels=repeat_channels,
 922            to_tensor=to_tensor,
 923            resize=resize,
 924        )
 925
 926        self.original_dataset_constants: type[DatasetConstants] = (
 927            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
 928        )
 929        r"""The original dataset constants class. """
 930
 931        self.class_split: dict[int, list[int]] = OmegaConf.to_container(class_split)
 932        r"""The dict of class splits for each task."""
 933
 934        CLSplitDataset.sanity_check(self)
 935
 936    def sanity_check(self) -> None:
 937        r"""Sanity check."""
 938
 939        # check the class split
 940        expected_keys = set(range(1, self.num_tasks + 1))
 941        if set(self.class_split.keys()) != expected_keys:
 942            raise ValueError(
 943                f"{self.class_split} dict keys must be consecutive integers from 1 to num_tasks."
 944            )
 945        if any(len(split) < 2 for split in self.class_split.values()):
 946            raise ValueError("Each class split must contain at least 2 elements!")
 947
 948    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
 949        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
 950
 951        **Args:**
 952        - **task_id** (`int`): the task ID to query the CL class map.
 953
 954        **Returns:**
 955        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
 956            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
 957            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
 958        """
 959        num_classes_t = len(
 960            self.class_split[task_id]
 961        )  # the number of classes in the current task, i.e. the length of the class split
 962        class_map_t = (
 963            self.original_dataset_constants.CLASS_MAP
 964        )  # the same with the original dataset
 965
 966        if self.cl_paradigm == "TIL":
 967            return {
 968                class_map_t[self.class_split[task_id][i]]: i
 969                for i in range(num_classes_t)
 970            }
 971        if self.cl_paradigm == "CIL":
 972            num_classes_previous = sum(
 973                len(self.class_split[i]) for i in range(1, task_id)
 974            )
 975            return {
 976                class_map_t[self.class_split[task_id][i]]: num_classes_previous + i
 977                for i in range(num_classes_t)
 978            }
 979
 980    def setup_task_id(self, task_id: int) -> None:
 981        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
 982
 983        **Args:**
 984        - **task_id** (`int`): the target task ID.
 985        """
 986        super().setup_task_id(task_id)
 987
 988        self.mean_t = (
 989            self.original_dataset_constants.MEAN
 990        )  # the same with the original dataset
 991        self.std_t = (
 992            self.original_dataset_constants.STD
 993        )  # the same with the original dataset
 994
 995    @abstractmethod
 996    def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
 997        r"""Get a subset of classes from the dataset for the current task `self.task_id`. It is used when constructing the split. **It must be implemented by subclasses.**
 998
 999        **Args:**
1000        - **dataset** (`Dataset`): the dataset to retrieve the subset from.
1001
1002        **Returns:**
1003        - **subset** (`Dataset`): the subset of classes from the dataset.
1004        """

The base class of continual learning datasets constructed as splits of an original dataset.

CLSplitDataset( root: str, class_split: dict[int, list[int]], batch_size: int | dict[int, int] = 1, num_workers: int | dict[int, int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)
880    def __init__(
881        self,
882        root: str,
883        class_split: dict[int, list[int]],
884        batch_size: int | dict[int, int] = 1,
885        num_workers: int | dict[int, int] = 0,
886        custom_transforms: (
887            Callable
888            | transforms.Compose
889            | None
890            | dict[int, Callable | transforms.Compose | None]
891        ) = None,
892        repeat_channels: int | None | dict[int, int | None] = None,
893        to_tensor: bool | dict[int, bool] = True,
894        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
895    ) -> None:
896        r"""
897        **Args:**
898        - **root** (`str`): the root directory where the original dataset live.
899        - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs ane the values are lists of class labels (integers starting from 0) to split for each task.
900        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
901        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
902        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
903        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
904        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
905        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
906        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
907        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
908        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
909        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
910        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
911        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
912        """
913        super().__init__(
914            root=root,
915            num_tasks=len(
916                class_split
917            ),  # num_tasks is not explicitly provided, but derived from the class_split length
918            batch_size=batch_size,
919            num_workers=num_workers,
920            custom_transforms=custom_transforms,
921            repeat_channels=repeat_channels,
922            to_tensor=to_tensor,
923            resize=resize,
924        )
925
926        self.original_dataset_constants: type[DatasetConstants] = (
927            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
928        )
929        r"""The original dataset constants class. """
930
931        self.class_split: dict[int, list[int]] = OmegaConf.to_container(class_split)
932        r"""The dict of class splits for each task."""
933
934        CLSplitDataset.sanity_check(self)

Args:

  • root (str): the root directory where the original dataset lives.
  • class_split (dict[int, list[int]]): the dict of classes for each task. The keys are task IDs and the values are lists of class labels (integers starting from 0) to split for each task.
  • batch_size (int | dict[int, int]): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an int, it is the same batch size for all tasks.
  • num_workers (int | dict[int, int]): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an int, it is the same number of workers for all tasks.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset]

The original dataset class. It must be provided in subclasses.

original_dataset_constants: type[clarena.stl_datasets.raw.constants.DatasetConstants]

The original dataset constants class.

class_split: dict[int, list[int]]

The dict of class splits for each task.

def sanity_check(self) -> None:
936    def sanity_check(self) -> None:
937        r"""Sanity check."""
938
939        # check the class split
940        expected_keys = set(range(1, self.num_tasks + 1))
941        if set(self.class_split.keys()) != expected_keys:
942            raise ValueError(
943                f"{self.class_split} dict keys must be consecutive integers from 1 to num_tasks."
944            )
945        if any(len(split) < 2 for split in self.class_split.values()):
946            raise ValueError("Each class split must contain at least 2 elements!")

Sanity check.

def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
948    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
949        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
950
951        **Args:**
952        - **task_id** (`int`): the task ID to query the CL class map.
953
954        **Returns:**
955        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
956            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
957            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
958        """
959        num_classes_t = len(
960            self.class_split[task_id]
961        )  # the number of classes in the current task, i.e. the length of the class split
962        class_map_t = (
963            self.original_dataset_constants.CLASS_MAP
964        )  # the same with the original dataset
965
966        if self.cl_paradigm == "TIL":
967            return {
968                class_map_t[self.class_split[task_id][i]]: i
969                for i in range(num_classes_t)
970            }
971        if self.cl_paradigm == "CIL":
972            num_classes_previous = sum(
973                len(self.class_split[i]) for i in range(1, task_id)
974            )
975            return {
976                class_map_t[self.class_split[task_id][i]]: num_classes_previous + i
977                for i in range(num_classes_t)
978            }

Get the mapping of classes of task task_id to fit continual learning settings self.cl_paradigm.

Args:

  • task_id (int): the task ID to query the CL class map.

Returns:

  • cl_class_map (dict[str | int, int]): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task are consecutive integers from 0 up to the number of classes in the task.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task are consecutive integers starting from the total number of classes in all previous tasks, i.e. they continue where the previous task's labels ended.
def setup_task_id(self, task_id: int) -> None:
980    def setup_task_id(self, task_id: int) -> None:
981        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
982
983        **Args:**
984        - **task_id** (`int`): the target task ID.
985        """
986        super().setup_task_id(task_id)
987
988        self.mean_t = (
989            self.original_dataset_constants.MEAN
990        )  # the same with the original dataset
991        self.std_t = (
992            self.original_dataset_constants.STD
993        )  # the same with the original dataset

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.
@abstractmethod
def get_subset_of_classes( self, dataset: torch.utils.data.dataset.Dataset) -> torch.utils.data.dataset.Dataset:
 995    @abstractmethod
 996    def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
 997        r"""Get a subset of classes from the dataset for the current task `self.task_id`. It is used when constructing the split. **It must be implemented by subclasses.**
 998
 999        **Args:**
1000        - **dataset** (`Dataset`): the dataset to retrieve the subset from.
1001
1002        **Returns:**
1003        - **subset** (`Dataset`): the subset of classes from the dataset.
1004        """

Get a subset of classes from the dataset for the current task self.task_id. It is used when constructing the split. It must be implemented by subclasses.

Args:

  • dataset (Dataset): the dataset to retrieve the subset from.

Returns:

  • subset (Dataset): the subset of classes from the dataset.
class CLCombinedDataset(clarena.cl_datasets.CLDataset):
1007class CLCombinedDataset(CLDataset):
1008    r"""The base class of continual learning datasets constructed as combinations of several single-task datasets (one dataset per task)."""
1009
1010    def __init__(
1011        self,
1012        datasets: dict[int, str],
1013        root: str | dict[int, str],
1014        batch_size: int | dict[int, int] = 1,
1015        num_workers: int | dict[int, int] = 0,
1016        custom_transforms: (
1017            Callable
1018            | transforms.Compose
1019            | None
1020            | dict[int, Callable | transforms.Compose | None]
1021        ) = None,
1022        repeat_channels: int | None | dict[int, int | None] = None,
1023        to_tensor: bool | dict[int, bool] = True,
1024        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
1025    ) -> None:
1026        r"""
1027        **Args:**
1028        - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
1029        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
1030        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
1031        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
1032        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
1033        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
1034        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
1035        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
1036        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
1037        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
1038        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
1039        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
1040        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
1041        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
1042        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
1043        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
1044        """
1045        super().__init__(
1046            root=root,
1047            num_tasks=len(
1048                datasets
1049            ),  # num_tasks is not explicitly provided, but derived from the datasets length
1050            batch_size=batch_size,
1051            num_workers=num_workers,
1052            custom_transforms=custom_transforms,
1053            repeat_channels=repeat_channels,
1054            to_tensor=to_tensor,
1055            resize=resize,
1056        )
1057
1058        self.original_dataset_python_classes: dict[int, Dataset] = {
1059            t: str_to_class(dataset_class_path)
1060            for t, dataset_class_path in datasets.items()
1061        }
1062        r"""The dict of dataset classes for each task."""
1063        self.original_dataset_python_class_t: Dataset
1064        r"""The dataset class for the current task `self.task_id`."""
1065        self.original_dataset_constants_t: type[DatasetConstants]
1066        r"""The original dataset constants class for the current task `self.task_id`."""
1067
1068    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
1069        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
1070
1071        **Args:**
1072        - **task_id** (`int`): the task ID to query the CL class map.
1073
1074        **Returns:**
1075        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
1076            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
1077            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
1078        """
1079        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]
1080        original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
1081            original_dataset_python_class_t
1082        ]
1083        num_classes_t = original_dataset_constants_t.NUM_CLASSES
1084        class_map_t = original_dataset_constants_t.CLASS_MAP
1085
1086        if self.cl_paradigm == "TIL":
1087            return {class_map_t[i]: i for i in range(num_classes_t)}
1088        if self.cl_paradigm == "CIL":
1089            num_classes_previous = sum(
1090                [
1091                    DATASET_CONSTANTS_MAPPING[
1092                        self.original_dataset_python_classes[i]
1093                    ].NUM_CLASSES
1094                    for i in range(1, task_id)
1095                ]
1096            )
1097            return {
1098                class_map_t[i]: num_classes_previous + i for i in range(num_classes_t)
1099            }
1100
1101    def setup_task_id(self, task_id: int) -> None:
1102        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
1103
1104        **Args:**
1105        - **task_id** (`int`): the target task ID.
1106        """
1107
1108        self.original_dataset_python_class_t = self.original_dataset_python_classes[
1109            task_id
1110        ]
1111
1112        self.original_dataset_constants_t: type[DatasetConstants] = (
1113            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class_t]
1114        )
1115
1116        super().setup_task_id(task_id)
1117
1118        self.mean_t = self.original_dataset_constants_t.MEAN
1119        self.std_t = self.original_dataset_constants_t.STD

The base class of continual learning datasets constructed as combinations of several single-task datasets (one dataset per task).

CLCombinedDataset( datasets: dict[int, str], root: str | dict[int, str], batch_size: int | dict[int, int] = 1, num_workers: int | dict[int, int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)
1010    def __init__(
1011        self,
1012        datasets: dict[int, str],
1013        root: str | dict[int, str],
1014        batch_size: int | dict[int, int] = 1,
1015        num_workers: int | dict[int, int] = 0,
1016        custom_transforms: (
1017            Callable
1018            | transforms.Compose
1019            | None
1020            | dict[int, Callable | transforms.Compose | None]
1021        ) = None,
1022        repeat_channels: int | None | dict[int, int | None] = None,
1023        to_tensor: bool | dict[int, bool] = True,
1024        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
1025    ) -> None:
1026        r"""
1027        **Args:**
1028        - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
1029        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
1030        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
1031        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
1032        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
1033        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
1034        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
1035        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
1036        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
1037        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
1038        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
1039        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
1040        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
1041        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
1042        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
1043        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
1044        """
1045        super().__init__(
1046            root=root,
1047            num_tasks=len(
1048                datasets
1049            ),  # num_tasks is not explicitly provided, but derived from the datasets length
1050            batch_size=batch_size,
1051            num_workers=num_workers,
1052            custom_transforms=custom_transforms,
1053            repeat_channels=repeat_channels,
1054            to_tensor=to_tensor,
1055            resize=resize,
1056        )
1057
1058        self.original_dataset_python_classes: dict[int, Dataset] = {
1059            t: str_to_class(dataset_class_path)
1060            for t, dataset_class_path in datasets.items()
1061        }
1062        r"""The dict of dataset classes for each task."""
1063        self.original_dataset_python_class_t: Dataset
1064        r"""The dataset class for the current task `self.task_id`."""
1065        self.original_dataset_constants_t: type[DatasetConstants]
1066        r"""The original dataset constants class for the current task `self.task_id`."""

Args:

  • datasets (dict[int, str]): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
  • root (str | dict[int, str]): the root directory where the original data files for constructing the CL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
  • num_tasks: note that, unlike other CL datasets, this is not a constructor argument here — it is derived automatically as len(datasets). It decides the valid task IDs from 1 to num_tasks.
  • batch_size (int | dict[int, int]): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an int, it is the same batch size for all tasks.
  • num_workers (int | dict[int, int]): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an int, it is the same number of workers for all tasks.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
original_dataset_python_classes: dict[int, torch.utils.data.dataset.Dataset]

The dict of dataset classes for each task.

original_dataset_python_class_t: torch.utils.data.dataset.Dataset

The dataset class for the current task self.task_id.

original_dataset_constants_t: type[clarena.stl_datasets.raw.constants.DatasetConstants]

The original dataset constants class for the current task self.task_id.

def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
1068    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
1069        r"""Get the mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
1070
1071        **Args:**
1072        - **task_id** (`int`): the task ID to query the CL class map.
1073
1074        **Returns:**
1075        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
1076            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
1077            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
1078        """
1079        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]
1080        original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
1081            original_dataset_python_class_t
1082        ]
1083        num_classes_t = original_dataset_constants_t.NUM_CLASSES
1084        class_map_t = original_dataset_constants_t.CLASS_MAP
1085
1086        if self.cl_paradigm == "TIL":
1087            return {class_map_t[i]: i for i in range(num_classes_t)}
1088        if self.cl_paradigm == "CIL":
1089            num_classes_previous = sum(
1090                [
1091                    DATASET_CONSTANTS_MAPPING[
1092                        self.original_dataset_python_classes[i]
1093                    ].NUM_CLASSES
1094                    for i in range(1, task_id)
1095                ]
1096            )
1097            return {
1098                class_map_t[i]: num_classes_previous + i for i in range(num_classes_t)
1099            }

Get the mapping of classes of task task_id to fit continual learning settings self.cl_paradigm.

Args:

  • task_id (int): the task ID to query the CL class map.

Returns:

  • cl_class_map (dict[str | int, int]): the CL class map of the task. Keys are the original class labels and values are the integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task are consecutive integers from 0 up to the number of classes in the task.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task are consecutive integers starting from the total number of classes in all previous tasks, i.e. they continue where the previous task's labels ended.
def setup_task_id(self, task_id: int) -> None:
1101    def setup_task_id(self, task_id: int) -> None:
1102        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
1103
1104        **Args:**
1105        - **task_id** (`int`): the target task ID.
1106        """
1107
1108        self.original_dataset_python_class_t = self.original_dataset_python_classes[
1109            task_id
1110        ]
1111
1112        self.original_dataset_constants_t: type[DatasetConstants] = (
1113            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class_t]
1114        )
1115
1116        super().setup_task_id(task_id)
1117
1118        self.mean_t = self.original_dataset_constants_t.MEAN
1119        self.std_t = self.original_dataset_constants_t.STD

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.