clarena.cl_datasets

Continual Learning Datasets

This submodule provides the continual learning datasets that can be used in CLArena.

Here are the base classes for continual learning datasets, which inherit from Lightning LightningDataModule:

  • CLDataset: The base class for all continual learning datasets.
    • CLPermutedDataset: The base class for permuted continual learning datasets. A child class of CLDataset.
    • CLSplitDataset: The base class for split continual learning datasets. A child class of CLDataset.
    • CLCombinedDataset: The base class for combined continual learning datasets. A child class of CLDataset.

Please note that this is the API documentation. Please refer to the main documentation pages for more information about how to configure and implement continual learning datasets:

  • Configure CL Dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/CL-dataset
  • Implement Custom CL Dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl_dataset
  • A Beginners' Guide to Continual Learning (CL Dataset): https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset

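Before the reference below, a minimal quick-start sketch. It assumes PermutedMNIST keeps the constructor signature of its base class CLPermutedDataset (documented further down); the path and hyperparameter values are illustrative, not defaults.

    from clarena.cl_datasets import PermutedMNIST

    # Construct a 10-task permuted-MNIST benchmark (illustrative values).
    cl_dataset = PermutedMNIST(
        root="data/mnist",  # hypothetical root directory
        num_tasks=10,
        batch_size=64,
        num_workers=4,
    )
    # The paradigm must be set before class maps are queried.
    cl_dataset.set_cl_paradigm("TIL")  # 'TIL' or 'CIL'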
  1r"""
  2
  3# Continual Learning Datasets
  4
  5This submodule provides the **continual learning datasets** that can be used in CLArena.
  6
  7Here are the base classes for continual learning datasets, which inherit from Lightning `LightningDataModule`:
  8
  9- `CLDataset`: The base class for all continual learning datasets.
 10    - `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
 11    - `CLSplitDataset`: The base class for split continual learning datasets. A child class of `CLDataset`.
 12    - `CLCombinedDataset`: The base class for combined continual learning datasets. A child class of `CLDataset`.
 13
 14Please note that this is an API documantation. Please refer to the main documentation pages for more information about how to configure and implement continual learning datasets:
 15
 16- [**Configure CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/CL-dataset)
 17- [**Implement Custom CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl_dataset)
 18- [**A Beginners' Guide to Continual Learning (CL Dataset)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset)
 19
 20
 21"""
 22
 23from .base import (
 24    CLDataset,
 25    CLPermutedDataset,
 26    CLSplitDataset,
 27    CLCombinedDataset,
 28)
 29
 30from .permuted_mnist import PermutedMNIST
 31from .permuted_emnist import PermutedEMNIST
 32from .permuted_fashionmnist import PermutedFashionMNIST
 33from .permuted_kmnist import PermutedKMNIST
 34from .permuted_notmnist import PermutedNotMNIST
 35from .permuted_sign_language_mnist import PermutedSignLanguageMNIST
 36from .permuted_ahdd import PermutedArabicHandwrittenDigits
 37from .permuted_kannadamnist import PermutedKannadaMNIST
 38from .permuted_svhn import PermutedSVHN
 39from .permuted_country211 import PermutedCountry211
 40from .permuted_imagenette import PermutedImagenette
 41from .permuted_dtd import PermutedDTD
 42from .permuted_cifar10 import PermutedCIFAR10
 43from .permuted_cifar100 import PermutedCIFAR100
 44from .permuted_caltech101 import PermutedCaltech101
 45from .permuted_caltech256 import PermutedCaltech256
 46from .permuted_eurosat import PermutedEuroSAT
 47from .permuted_fgvc_aircraft import PermutedFGVCAircraft
 48from .permuted_flowers102 import PermutedFlowers102
 49from .permuted_food101 import PermutedFood101
 50from .permuted_celeba import PermutedCelebA
 51from .permuted_fer2013 import PermutedFER2013
 52from .permuted_tinyimagenet import PermutedTinyImageNet
 53from .permuted_oxford_iiit_pet import PermutedOxfordIIITPet
 54from .permuted_pcam import PermutedPCAM
 55from .permuted_renderedsst2 import PermutedRenderedSST2
 56from .permuted_stanfordcars import PermutedStanfordCars
 57from .permuted_sun397 import PermutedSUN397
 58from .permuted_usps import PermutedUSPS
 59from .permuted_SEMEION import PermutedSEMEION
 60from .permuted_facescrub import PermutedFaceScrub
 61from .permuted_cub2002011 import PermutedCUB2002011
 62from .permuted_gtsrb import PermutedGTSRB
 63from .permuted_linnaeus5 import PermutedLinnaeus5
 64
 65from .split_cifar10 import SplitCIFAR10
 66from .split_mnist import SplitMNIST
 67from .split_cifar100 import SplitCIFAR100
 68from .split_tinyimagenet import SplitTinyImageNet
 69from .split_cub2002011 import SplitCUB2002011
 70
 71from .combined import Combined
 72
 73
 74__all__ = [
 75    "CLDataset",
 76    "CLPermutedDataset",
 77    "CLSplitDataset",
 78    "CLCombinedDataset",
 79    "combined",
 80    "permuted_mnist",
 81    "permuted_emnist",
 82    "permuted_fashionmnist",
 83    "permuted_imagenette",
 84    "permuted_sign_language_mnist",
 85    "permuted_ahdd",
 86    "permuted_kannadamnist",
 87    "permuted_country211",
 88    "permuted_dtd",
 89    "permuted_fer2013",
 90    "permuted_fgvc_aircraft",
 91    "permuted_flowers102",
 92    "permuted_food101",
 93    "permuted_kmnist",
 94    "permuted_notmnist",
 95    "permuted_svhn",
 96    "permuted_cifar10",
 97    "permuted_cifar100",
 98    "permuted_caltech101",
 99    "permuted_caltech256",
100    "permuted_oxford_iiit_pet",
101    "permuted_celeba",
102    "permuted_eurosat",
103    "permuted_facescrub",
104    "permuted_pcam",
105    "permuted_renderedsst2",
106    "permuted_stanfordcars",
107    "permuted_sun397",
108    "permuted_usps",
109    "permuted_SEMEION",
110    "permuted_tinyimagenet",
111    "permuted_cub2002011",
112    "permuted_gtsrb",
113    "permuted_linnaeus5",
114    "split_mnist",
115    "split_cifar10",
116    "split_cifar100",
117    "split_tinyimagenet",
118    "split_cub2002011",
119]
class CLDataset(lightning.pytorch.core.datamodule.LightningDataModule):

class CLDataset(LightningDataModule):
    r"""The base class of continual learning datasets."""

    def __init__(
        self,
        root: str | dict[int, str],
        num_tasks: int,
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        """
        super().__init__()

        self.root: dict[int, str] = (
            OmegaConf.to_container(root)
            if isinstance(root, DictConfig)
            else {t: root for t in range(1, num_tasks + 1)}
        )
        r"""The dict of root directories of the original data files for each task."""
        self.num_tasks: int = num_tasks
        r"""The maximum number of tasks supported by the dataset."""
        self.cl_paradigm: str
        r"""The continual learning paradigm."""
        self.batch_size: dict[int, int] = (
            OmegaConf.to_container(batch_size)
            if isinstance(batch_size, DictConfig)
            else {t: batch_size for t in range(1, num_tasks + 1)}
        )
        r"""The dict of batch sizes for each task."""
        self.num_workers: dict[int, int] = (
            OmegaConf.to_container(num_workers)
            if isinstance(num_workers, DictConfig)
            else {t: num_workers for t in range(1, num_tasks + 1)}
        )
        r"""The dict of numbers of workers for each task."""
        self.custom_transforms: dict[int, Callable | transforms.Compose | None] = (
            OmegaConf.to_container(custom_transforms)
            if isinstance(custom_transforms, DictConfig)
            else {t: custom_transforms for t in range(1, num_tasks + 1)}
        )
        r"""The dict of custom transforms for each task."""
        self.repeat_channels: dict[int, int | None] = (
            OmegaConf.to_container(repeat_channels)
            if isinstance(repeat_channels, DictConfig)
            else {t: repeat_channels for t in range(1, num_tasks + 1)}
        )
        r"""The dict of number of channels to repeat for each task."""
        self.to_tensor: dict[int, bool] = (
            OmegaConf.to_container(to_tensor)
            if isinstance(to_tensor, DictConfig)
            else {t: to_tensor for t in range(1, num_tasks + 1)}
        )
        r"""The dict of to_tensor flags for each task."""
        self.resize: dict[int, tuple[int, int] | None] = (
            {t: tuple(rs) if rs else None for t, rs in resize.items()}
            if isinstance(resize, DictConfig)
            else {
                t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1)
            }
        )
        r"""The dict of sizes to resize to for each task."""

        # task-specific attributes
        self.root_t: str
        r"""The root directory of the original data files for the current task `self.task_id`."""
        self.batch_size_t: int
        r"""The batch size for the current task `self.task_id`."""
        self.num_workers_t: int
        r"""The number of workers for the current task `self.task_id`."""
        self.custom_transforms_t: Callable | transforms.Compose | None
        r"""The custom transforms for the current task `self.task_id`."""
        self.repeat_channels_t: int | None
        r"""The number of channels to repeat for the current task `self.task_id`."""
        self.to_tensor_t: bool
        r"""The to_tensor flag for the current task `self.task_id`."""
        self.resize_t: tuple[int, int] | None
        r"""The size to resize for the current task `self.task_id`."""
        self.mean_t: float
        r"""The mean values for normalization for the current task `self.task_id`."""
        self.std_t: float
        r"""The standard deviation values for normalization for the current task `self.task_id`."""

        # dataset containers
        self.dataset_train_t: Any
        r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
        self.dataset_val_t: Any
        r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
        self.dataset_test: dict[int, Any] = {}
        r"""The dictionary that stores the test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects."""

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed; updated automatically during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

        # call the base-class check explicitly so that subclass overrides do not run here
        CLDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check: every per-task argument dict must have keys covering task IDs 1 to `num_tasks`."""

        # check if each task has been provided with necessary arguments
        for attr in [
            "root",
            "batch_size",
            "num_workers",
            "custom_transforms",
            "repeat_channels",
            "to_tensor",
            "resize",
        ]:
            value = getattr(self, attr)
            expected_keys = set(range(1, self.num_tasks + 1))
            if set(value.keys()) != expected_keys:
                raise ValueError(
                    f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
                )

    @abstractmethod
    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping from the original classes of task `task_id` to continual learning class labels, according to `self.cl_paradigm`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of each task are consecutive integers from 0 to the number of classes of the task minus 1.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of each task are consecutive integers starting from the total number of classes of all previous tasks.
        """

    @abstractmethod
    def prepare_data(self) -> None:
        r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`. This method is called at the beginning of each task."""

    def setup(self, stage: str) -> None:
        r"""Set up the dataset for different stages. This method is called at the beginning of each task.

        **Args:**
        - **stage** (`str`): the stage of the experiment; one of:
            - 'fit': training and validation datasets of the current task `self.task_id` are assigned to `self.dataset_train_t` and `self.dataset_val_t`.
            - 'test': a dict of test datasets of all seen tasks should be assigned to `self.dataset_test`.
        """
        if stage == "fit":
            # these two stages must be done together because a sanity check for validation is conducted before training
            pylogger.debug(
                "Construct train and validation dataset for task %d...", self.task_id
            )

            self.dataset_train_t, self.dataset_val_t = self.train_and_val_dataset()

            pylogger.info(
                "Train and validation dataset for task %d are ready.", self.task_id
            )
            pylogger.info(
                "Train dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_train_t),
            )
            pylogger.info(
                "Validation dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_val_t),
            )

        elif stage == "test":

            pylogger.debug("Construct test dataset for task %d...", self.task_id)

            self.dataset_test[self.task_id] = self.test_dataset()

            pylogger.info("Test dataset for task %d is ready.", self.task_id)
            pylogger.info(
                "Test dataset for task %d size: %d",
                self.task_id,
                len(self.dataset_test[self.task_id]),
            )

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        self.task_id = task_id

        self.root_t = self.root[task_id]
        self.batch_size_t = self.batch_size[task_id]
        self.num_workers_t = self.num_workers[task_id]
        self.custom_transforms_t = self.custom_transforms[task_id]
        self.repeat_channels_t = self.repeat_channels[task_id]
        self.to_tensor_t = self.to_tensor[task_id]
        self.resize_t = self.resize[task_id]

        self.processed_task_ids.append(task_id)

    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
        r"""Set up tasks for continual learning main evaluation.

        **Args:**
        - **eval_tasks** (`list[int]`): the list of task IDs to evaluate.
        """
        for task_id in eval_tasks:
            self.setup_task_id(task_id=task_id)
            self.setup(stage="test")

    def set_cl_paradigm(self, cl_paradigm: str) -> None:
        r"""Set `cl_paradigm` to `self.cl_paradigm`. It is used to define the CL class map.

        **Args:**
        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
        """
        self.cl_paradigm = cl_paradigm

    def train_and_val_transforms(self) -> transforms.Compose:
        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.custom_transforms_t,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

    def test_transforms(self) -> transforms.Compose:
        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. It is used in subclasses when constructing the dataset.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
            if self.repeat_channels_t is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
        resize_transform = (
            transforms.Resize(self.resize_t) if self.resize_t is not None else None
        )
        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters. No custom transforms for test

    def target_transform(self) -> ClassMapping:
        r"""Target transform to map the original class labels to CL class labels according to `self.cl_paradigm`. It can be used in subclasses when constructing the dataset.

        **Returns:**
        - **target_transform** (`Callable`): the target transform function.
        """

        cl_class_map = self.get_cl_class_map(task_id=self.task_id)

        target_transform = ClassMapping(class_map=cl_class_map)

        return target_transform

    @abstractmethod
    def train_and_val_dataset(self) -> tuple[Any, Any]:
        r"""Get the training and validation datasets of the current task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation datasets of the current task `self.task_id`.
        """

    @abstractmethod
    def test_dataset(self) -> Any:
        r"""Get the test dataset of the current task `self.task_id`. It must be implemented by subclasses.

        **Returns:**
        - **test_dataset** (`Any`): the test dataset of the current task `self.task_id`.
        """

    def train_dataloader(self) -> DataLoader:
        r"""DataLoader generator for the train stage of the current task `self.task_id`. It is automatically called before training the task.

        **Returns:**
        - **train_dataloader** (`DataLoader`): the train DataLoader of task `self.task_id`.
        """

        pylogger.debug("Construct train dataloader for task %d...", self.task_id)

        return DataLoader(
            dataset=self.dataset_train_t,
            batch_size=self.batch_size_t,
            shuffle=True,  # shuffle train batch to prevent overfitting
            num_workers=self.num_workers_t,
            drop_last=True,  # avoid BatchNorm errors when the last batch has size 1
        )

    def val_dataloader(self) -> DataLoader:
        r"""DataLoader generator for the validation stage of the current task `self.task_id`. It is automatically called before the task's validation.

        **Returns:**
        - **val_dataloader** (`DataLoader`): the validation DataLoader of task `self.task_id`.
        """

        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)

        return DataLoader(
            dataset=self.dataset_val_t,
            batch_size=self.batch_size_t,
            shuffle=False,  # no need to shuffle val or test batches
            num_workers=self.num_workers_t,
        )

    def test_dataloader(self) -> dict[int, DataLoader]:
        r"""DataLoader generator for the test stage of the current task `self.task_id`. It is automatically called before testing the task.

        **Returns:**
        - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
        """

        pylogger.debug("Construct test dataloader for task %d...", self.task_id)

        return {
            task_id: DataLoader(
                dataset=dataset_test_t,
                batch_size=self.batch_size_t,
                shuffle=False,  # no need to shuffle val or test batches
                num_workers=self.num_workers_t,
            )
            for task_id, dataset_test_t in self.dataset_test.items()
        }

    def __len__(self) -> int:
        r"""Get the number of tasks in the dataset.

        **Returns:**
        - **num_tasks** (`int`): the number of tasks in the dataset.
        """
        return self.num_tasks

The base class of continual learning datasets.

CLDataset( root: str | dict[int, str], num_tasks: int, batch_size: int | dict[int, int] = 1, num_workers: int | dict[int, int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)

Args:

  • root (str | dict[int, str]): the root directory where the original data files for constructing the CL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to num_tasks.
  • batch_size (int | dict[int, int]): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an int, it is the same batch size for all tasks.
  • num_workers (int | dict[int, int]): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an int, it is the same number of workers for all tasks.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
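To make the scalar-versus-dict convention above concrete, here is a short sketch; the values are illustrative only. Dict keys must be exactly the task IDs 1 to num_tasks, which sanity_check() enforces.

    # Scalar form: one value shared by all tasks.
    batch_size = 64

    # Dict form: one value per task, keyed by task ID (1..num_tasks).
    batch_size = {1: 64, 2: 128, 3: 64}
    resize = {1: (32, 32), 2: None, 3: (64, 64)}  # None means no resize for task 2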
root: dict[int, str]

The dict of root directories of the original data files for each task.

num_tasks: int

The maximum number of tasks supported by the dataset.

cl_paradigm: str

The continual learning paradigm.

batch_size: dict[int, int]

The dict of batch sizes for each task.

num_workers: dict[int, int]

The dict of numbers of workers for each task.

custom_transforms: dict[int, typing.Union[typing.Callable, torchvision.transforms.transforms.Compose, NoneType]]

The dict of custom transforms for each task.

repeat_channels: dict[int, int | None]

The dict of number of channels to repeat for each task.

to_tensor: dict[int, bool]

The dict of to_tensor flags for each task.

resize: dict[int, tuple[int, int] | None]

The dict of sizes to resize to for each task.

root_t: str

The root directory of the original data files for the current task self.task_id.

batch_size_t: int

The batch size for the current task self.task_id.

num_workers_t: int

The number of workers for the current task self.task_id.

custom_transforms_t: Union[Callable, torchvision.transforms.transforms.Compose, NoneType]

The custom transforms for the current task self.task_id.

repeat_channels_t: int | None

The number of channels to repeat for the current task self.task_id.

to_tensor_t: bool

The to_tensor flag for the current task self.task_id.

resize_t: tuple[int, int] | None

The size to resize for the current task self.task_id.

mean_t: float

The mean values for normalization for the current task self.task_id.

std_t: float

The standard deviation values for normalization for the current task self.task_id.

dataset_train_t: Any

The training dataset object. Can be a PyTorch Dataset object or any other dataset object.

dataset_val_t: Any

The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.

dataset_test: dict[int, typing.Any]

The dictionary that stores the test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

task_id: int

Task ID counter indicating which task is being processed; updated automatically during the task loop. Valid from 1 to the number of tasks in the CL dataset.

processed_task_ids: list[int]

Task IDs that have been processed.

def sanity_check(self) -> None:

Sanity check: every per-task argument dict (root, batch_size, num_workers, custom_transforms, repeat_channels, to_tensor, resize) must have keys covering task IDs 1 to num_tasks.

@abstractmethod
def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:

Get the mapping from the original classes of task task_id to continual learning class labels, according to self.cl_paradigm. It must be implemented by subclasses.

Args:

  • task_id (int): the task ID to query the CL class map.

Returns:

  • cl_class_map (dict[str | int, int]): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning, as illustrated after this list.
    • If self.cl_paradigm is 'TIL', the mapped class labels of each task are consecutive integers from 0 to the number of classes of the task minus 1.
    • If self.cl_paradigm is 'CIL', the mapped class labels of each task are consecutive integers starting from the total number of classes of all previous tasks.
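As an illustration of the two paradigms, suppose task 2 has original classes 2 and 3, and task 1 already covered two classes. The maps below follow the rules stated above; the numbers are illustrative.

    # TIL: labels of every task restart from 0.
    cl_class_map_til = {2: 0, 3: 1}
    # CIL: labels continue after the 2 classes of all previous tasks.
    cl_class_map_cil = {2: 2, 3: 3}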
@abstractmethod
def prepare_data(self) -> None:

Use this to download and prepare data. It must be implemented by subclasses, as required by LightningDataModule. This method is called at the beginning of each task.

def setup(self, stage: str) -> None:

Set up the dataset for different stages. This method is called at the beginning of each task.

Args:

  • stage (str): the stage of the experiment; one of:
    • 'fit': training and validation datasets of the current task self.task_id are assigned to self.dataset_train_t and self.dataset_val_t.
    • 'test': a dict of test datasets of all seen tasks should be assigned to self.dataset_test.
def setup_task_id(self, task_id: int) -> None:

Set up which task's dataset the CL experiment is on. This must be done before the setup() method is called; see the ordering sketch below.

Args:

  • task_id (int): the target task ID.
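The ordering sketch referenced above: a hypothetical per-task loop that respects the contract in these docs (setup_task_id() before setup(), 'fit' before training, 'test' before testing). In a real CLArena run the Lightning trainer drives prepare_data(), setup(), and the dataloader hooks; this loop only makes the required order explicit.

    cl_dataset.set_cl_paradigm("TIL")  # once, before the task loop
    for task_id in range(1, cl_dataset.num_tasks + 1):
        cl_dataset.setup_task_id(task_id)  # selects root_t, batch_size_t, ...
        cl_dataset.prepare_data()
        cl_dataset.setup(stage="fit")      # builds dataset_train_t / dataset_val_t
        train_loader = cl_dataset.train_dataloader()
        # ... train on task task_id ...
        cl_dataset.setup(stage="test")     # adds this task's test dataset
        test_loaders = cl_dataset.test_dataloader()  # dict: task ID -> DataLoader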
def setup_tasks_eval(self, eval_tasks: list[int]) -> None:

Set up tasks for continual learning main evaluation.

Args:

  • eval_tasks (list[int]): the list of task IDs to evaluate.
def set_cl_paradigm(self, cl_paradigm: str) -> None:

Set cl_paradigm to self.cl_paradigm. It is used to define the CL class map.

Args:

  • cl_paradigm (str): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
def train_and_val_transforms(self) -> torchvision.transforms.transforms.Compose:

Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and ToTensor(). It can be used in subclasses when constructing the dataset.

Returns:

  • train_and_val_transforms (transforms.Compose): the composed train/val transforms.
def test_transforms(self) -> torchvision.transforms.transforms.Compose:

Transforms for the test dataset. Only basic transforms like normalization and ToTensor() are included. It is used in subclasses when constructing the dataset.

Returns:

  • test_transforms (transforms.Compose): the composed test transforms.
def target_transform(self) -> clarena.utils.transforms.ClassMapping:

Target transform to map the original class labels to CL class labels according to self.cl_paradigm. It can be used in subclasses when constructing the dataset.

Returns:

  • target_transform (Callable): the target transform function.
@abstractmethod
def train_and_val_dataset(self) -> tuple[typing.Any, typing.Any]:

Get the training and validation datasets of the current task self.task_id. It must be implemented by subclasses.

Returns:

  • train_and_val_dataset (tuple[Any, Any]): the train and validation datasets of the current task self.task_id.
@abstractmethod
def test_dataset(self) -> Any:

Get the test dataset of the current task self.task_id. It must be implemented by subclasses.

Returns:

  • test_dataset (Any): the test dataset of the current task self.task_id.
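For concreteness, a hedged sketch of a subclass wiring these two hooks to a torchvision dataset. The use of train_and_val_transforms(), test_transforms(), and target_transform() follows the docs above; the dataset choice, the 90/10 split, and the identity class map are illustrative assumptions, and a real subclass would also set mean_t and std_t (the transform helpers read them), typically from its dataset constants.

    from torch.utils.data import random_split
    from torchvision.datasets import MNIST

    class MyMNISTDataset(CLDataset):
        # NOTE: mean_t / std_t must be assigned before the transform helpers
        # run; real subclasses take them from their dataset constants.

        def get_cl_class_map(self, task_id):
            return {c: c for c in range(10)}  # identity map, illustrative

        def prepare_data(self):
            MNIST(root=self.root_t, train=True, download=True)
            MNIST(root=self.root_t, train=False, download=True)

        def train_and_val_dataset(self):
            full = MNIST(
                root=self.root_t,
                train=True,
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )
            return tuple(random_split(full, [0.9, 0.1]))  # illustrative split

        def test_dataset(self):
            return MNIST(
                root=self.root_t,
                train=False,
                transform=self.test_transforms(),
                target_transform=self.target_transform(),
            )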
def train_dataloader(self) -> torch.utils.data.dataloader.DataLoader:

DataLoader generator for the train stage of the current task self.task_id. It is automatically called before training the task.

Returns:

  • train_dataloader (DataLoader): the train DataLoader of task self.task_id.
def val_dataloader(self) -> torch.utils.data.dataloader.DataLoader:

DataLoader generator for the validation stage of the current task self.task_id. It is automatically called before the task's validation.

Returns:

  • val_dataloader (DataLoader): the validation DataLoader of task self.task_id.
def test_dataloader(self) -> dict[int, torch.utils.data.dataloader.DataLoader]:

DataLoader generator for the test stage of the current task self.task_id. It is automatically called before testing the task.

Returns:

  • test_dataloader (dict[int, DataLoader]): the test DataLoader dict of self.task_id and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs and values are the DataLoaders.
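A small sketch of consuming this dict during evaluation; the metric computation is elided.

    test_loaders = cl_dataset.test_dataloader()
    for task_id, loader in test_loaders.items():
        for inputs, targets in loader:
            ...  # evaluate the model on task task_id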
class CLPermutedDataset(clarena.cl_datasets.CLDataset):
428class CLPermutedDataset(CLDataset):
429    r"""The base class of continual learning datasets constructed as permutations of an original dataset."""
430
431    original_dataset_python_class: type[Dataset]
432    r"""The original dataset class. **It must be provided in subclasses.** """
433
434    def __init__(
435        self,
436        root: str,
437        num_tasks: int,
438        batch_size: int | dict[int, int] = 1,
439        num_workers: int | dict[int, int] = 0,
440        custom_transforms: (
441            Callable
442            | transforms.Compose
443            | None
444            | dict[int, Callable | transforms.Compose | None]
445        ) = None,
446        repeat_channels: int | None | dict[int, int | None] = None,
447        to_tensor: bool | dict[int, bool] = True,
448        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
449        permutation_mode: str = "first_channel_only",
450        permutation_seeds: dict[int, int] | None = None,
451    ) -> None:
452        r"""
453        **Args:**
454        - **root** (`str`): the root directory where the original dataset live.
455        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. This decides the valid task IDs from 1 to `num_tasks`.
456        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
457        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
458        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
459        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
460        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
461        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
462        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
463        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
464        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
465        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
466        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
467        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
468        - **permutation_mode** (`str`): the mode of permutation; one of:
469            1. 'all': permute all pixels.
470            2. 'by_channel': permute each channel separately, applying the same permutation order to every channel.
471            3. 'first_channel_only': permute only the first channel.
472        - **permutation_seeds** (`dict[int, int]` | `None`): the dict of permutation seeds used to construct each task. Keys are task IDs and values are the seeds. Default is `None`, which assigns seed `t - 1` to task `t` (i.e. seeds 0 to `num_tasks` - 1).
473        """
474        super().__init__(
475            root=root,
476            num_tasks=num_tasks,
477            batch_size=batch_size,
478            num_workers=num_workers,
479            custom_transforms=custom_transforms,
480            repeat_channels=repeat_channels,
481            to_tensor=to_tensor,
482            resize=resize,
483        )
484
485        self.original_dataset_constants: type[DatasetConstants] = (
486            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
487        )
488        r"""The original dataset constants class."""
489
490        self.permutation_mode: str = permutation_mode
491        r"""The mode of permutation."""
492        self.permutation_seeds: dict[int, int] = (
493            permutation_seeds
494            if permutation_seeds
495            else {t: t - 1 for t in range(1, num_tasks + 1)}
496        )
497        r"""The dict of permutation seeds for each task."""
498
499        self.permutation_seed_t: int
500        r"""The permutation seed for the current task `self.task_id`."""
501        self.permute_transform_t: Permute
502        r"""The permutation transform for the current task `self.task_id`."""
503
504        CLPermutedDataset.sanity_check(self)  # call this class's sanity check explicitly, not a subclass override
505
506    def sanity_check(self) -> None:
507        r"""Check the validity of the permutation mode and seeds."""
508
509        # check the permutation mode
510        if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]:
511            raise ValueError(
512                "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'."
513            )
514
515        # check the permutation seeds
516        expected_keys = set(range(1, self.num_tasks + 1))
517        if set(self.permutation_seeds.keys()) != expected_keys:
518            raise ValueError(
519                f"The keys of permutation_seeds must be consecutive integers from 1 to num_tasks; got {sorted(self.permutation_seeds.keys())}."
520            )
521
522    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
523        r"""Get the mapping of classes of task `task_id` that fits the continual learning paradigm `self.cl_paradigm`.
524
525        **Args:**
526        - **task_id** (`int`): the task ID to query the CL class map.
527
528        **Returns:**
529        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
530            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task are consecutive integers starting from 0.
531            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task are consecutive integers starting from the total number of classes in all previous tasks.
532        """
533
534        num_classes_t = (
535            self.original_dataset_constants.NUM_CLASSES
536        )  # same as the original dataset
537        class_map_t = (
538            self.original_dataset_constants.CLASS_MAP
539        )  # same as the original dataset
540
541        if self.cl_paradigm == "TIL":
542            return {class_map_t[i]: i for i in range(num_classes_t)}
543        if self.cl_paradigm == "CIL":
544            return {
545                class_map_t[i]: i + (task_id - 1) * num_classes_t
546                for i in range(num_classes_t)
547            }
548
549    def setup_task_id(self, task_id: int) -> None:
550        r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.
551
552        **Args:**
553        - **task_id** (`int`): the target task ID.
554        """
555
556        CLDataset.setup_task_id(self, task_id)
557
558        self.mean_t = (
559            self.original_dataset_constants.MEAN
560        )  # same as the original dataset
561        self.std_t = (
562            self.original_dataset_constants.STD
563        )  # same as the original dataset
564
565        num_channels = (
566            self.original_dataset_constants.NUM_CHANNELS
567            if self.repeat_channels_t is None
568            else self.repeat_channels_t
569        )
570
571        if (
572            hasattr(self.original_dataset_constants, "IMG_SIZE")
573            or self.resize_t is not None
574        ):
575            img_size = (
576                self.original_dataset_constants.IMG_SIZE
577                if self.resize_t is None
578                else torch.Size(self.resize_t)
579            )
580        else:
581            raise AttributeError(
582                "The original dataset does not have a fixed image size. Please specify the `resize` hyperparameter to resize images to a fixed size."
583            )
584
585        # set up the permutation transform
586        self.permutation_seed_t = self.permutation_seeds[task_id]
587        self.permute_transform_t = Permute(
588            num_channels=num_channels,
589            img_size=img_size,
590            mode=self.permutation_mode,
591            seed=self.permutation_seed_t,
592        )
593
594    def train_and_val_transforms(self) -> transforms.Compose:
595        r"""Transforms for training and validation datasets, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. In permuted CL datasets, a permute transform is also applied.
596
597        **Returns:**
598        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
599        """
600
601        repeat_channels_transform = (
602            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
603            if self.repeat_channels_t is not None
604            else None
605        )
606        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
607        resize_transform = (
608            transforms.Resize(self.resize_t) if self.resize_t is not None else None
609        )
610        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
611
612        return transforms.Compose(
613            list(
614                filter(
615                    None,
616                    [
617                        repeat_channels_transform,
618                        to_tensor_transform,
619                        resize_transform,
620                        self.permute_transform_t,  # permutation is included here
621                        self.custom_transforms_t,
622                        normalization_transform,
623                    ],
624                )
625            )
626        )  # the order of transforms matters
627
628    def test_transforms(self) -> transforms.Compose:
629        r"""Transforms for the test dataset. Only basic transforms like normalization and `ToTensor()` are included. In permuted CL datasets, a permute transform is also applied.
630
631        **Returns:**
632        - **test_transforms** (`transforms.Compose`): the composed test transforms.
633        """
634
635        repeat_channels_transform = (
636            transforms.Grayscale(num_output_channels=self.repeat_channels_t)
637            if self.repeat_channels_t is not None
638            else None
639        )
640        to_tensor_transform = transforms.ToTensor() if self.to_tensor_t else None
641        resize_transform = (
642            transforms.Resize(self.resize_t) if self.resize_t is not None else None
643        )
644        normalization_transform = transforms.Normalize(self.mean_t, self.std_t)
645
646        return transforms.Compose(
647            list(
648                filter(
649                    None,
650                    [
651                        repeat_channels_transform,
652                        to_tensor_transform,
653                        resize_transform,
654                        self.permute_transform_t,  # permutation is included here
655                        normalization_transform,
656                    ],
657                )
658            )
659        )  # the order of transforms matters; no custom transforms are applied at test time

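For orientation, a hypothetical subclass is sketched below. Binding `original_dataset_python_class` is the one hard requirement stated above; the dataset download/construction hooks inherited from the Lightning `LightningDataModule` workflow are elided, so this is a sketch rather than the actual `PermutedMNIST` implementation:

```python
from torchvision.datasets import MNIST

from clarena.cl_datasets import CLPermutedDataset


class MyPermutedMNIST(CLPermutedDataset):
    """A hypothetical permuted-MNIST dataset. The permutation machinery
    (per-task seeds, the `Permute` transform, CL class maps) is inherited
    from `CLPermutedDataset`; only the original dataset class is bound here.
    """

    original_dataset_python_class = MNIST

    # A real subclass would also implement the data preparation/setup hooks
    # (e.g. downloading MNIST and building per-task datasets); omitted here.
```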
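The class-map arithmetic in `get_cl_class_map` is easy to verify by hand. A self-contained toy example, with a 3-class map standing in for `self.original_dataset_constants.CLASS_MAP` (all names hypothetical):

```python
# Toy stand-ins for the original dataset constants.
num_classes_t = 3
class_map_t = {0: "cat", 1: "dog", 2: "bird"}  # class index -> original label


def cl_class_map(task_id: int, cl_paradigm: str) -> dict:
    if cl_paradigm == "TIL":
        # Every task maps its labels to 0..num_classes_t - 1.
        return {class_map_t[i]: i for i in range(num_classes_t)}
    if cl_paradigm == "CIL":
        # Task t's labels are offset by the classes of all previous tasks.
        return {
            class_map_t[i]: i + (task_id - 1) * num_classes_t
            for i in range(num_classes_t)
        }


print(cl_class_map(2, "TIL"))  # {'cat': 0, 'dog': 1, 'bird': 2}
print(cl_class_map(2, "CIL"))  # {'cat': 3, 'dog': 4, 'bird': 5}
```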
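The per-task permutation seeds exist so that each task's pixel permutation is reproducible across runs. A minimal illustration with `torch.randperm` (the internals of CLArena's `Permute` transform may differ):

```python
import torch


def pixel_permutation(seed: int, n: int = 28 * 28) -> torch.Tensor:
    # A fixed seed yields the same permutation every time it is constructed.
    g = torch.Generator().manual_seed(seed)
    return torch.randperm(n, generator=g)


assert torch.equal(pixel_permutation(0), pixel_permutation(0))  # same seed, same task
# Different seeds give different permutations (with overwhelming probability).
assert not torch.equal(pixel_permutation(0), pixel_permutation(1))
```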
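The `filter(None, ...)` idiom in both transform builders simply drops optional transforms that were left as `None` before composing. A self-contained sketch of the same pattern, with a toy first-channel permutation standing in for CLArena's `Permute`:

```python
import torch
from torchvision import transforms


class ToyPermute:
    """A stand-in for `clarena.utils.transforms.Permute` ('first_channel_only' mode)."""

    def __init__(self, seed: int, img_size: tuple[int, int] = (28, 28)):
        g = torch.Generator().manual_seed(seed)
        self.perm = torch.randperm(img_size[0] * img_size[1], generator=g)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        c, h, w = img.shape
        flat = img.clone().reshape(c, h * w)
        flat[0] = flat[0][self.perm]  # permute only the first channel
        return flat.reshape(c, h, w)


resize = None  # optional transforms may be None ...
composed = transforms.Compose(
    list(
        filter(
            None,  # ... and filter(None, ...) drops them before composing
            [
                transforms.ToTensor(),
                resize,
                ToyPermute(seed=0),
                transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean/std
            ],
        )
    )
)
```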
class CLSplitDataset(clarena.cl_datasets.CLDataset):
662class CLSplitDataset(CLDataset):
663    r"""The base class of continual learning datasets constructed as splits of an original dataset."""
664
665    original_dataset_python_class: type[Dataset]
666    r"""The original dataset class. **It must be provided in subclasses.** """
667
668    def __init__(
669        self,
670        root: str,
671        class_split: dict[int, list[int]],
672        batch_size: int | dict[int, int] = 1,
673        num_workers: int | dict[int, int] = 0,
674        custom_transforms: (
675            Callable
676            | transforms.Compose
677            | None
678            | dict[int, Callable | transforms.Compose | None]
679        ) = None,
680        repeat_channels: int | None | dict[int, int | None] = None,
681        to_tensor: bool | dict[int, bool] = True,
682        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
683    ) -> None:
684        r"""
685        **Args:**
686        - **root** (`str`): the root directory where the original dataset lives.
687        - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs and the values are lists of class labels (integers starting from 0) assigned to each task.
688        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
689        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
690        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
691        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
692        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. Basic transforms like `ToTensor()` and normalization are handled separately and must not be included here.
693        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
694        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
695        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
696        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
697        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
698        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
699        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
700        """
701        super().__init__(
702            root=root,
703            num_tasks=len(
704                class_split
705            ),  # num_tasks is not explicitly provided, but derived from the class_split length
706            batch_size=batch_size,
707            num_workers=num_workers,
708            custom_transforms=custom_transforms,
709            repeat_channels=repeat_channels,
710            to_tensor=to_tensor,
711            resize=resize,
712        )
713
714        self.original_dataset_constants: type[DatasetConstants] = (
715            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class]
716        )
717        r"""The original dataset constants class. """
718
719        self.class_split: dict[int, list[int]] = OmegaConf.to_container(class_split)
720        r"""The dict of class splits for each task."""
721
722        CLSplitDataset.sanity_check(self)  # call this class's sanity check explicitly, not a subclass override
723
724    def sanity_check(self) -> None:
725        r"""Check the validity of the class split."""
726
727        # check the class split
728        expected_keys = set(range(1, self.num_tasks + 1))
729        if set(self.class_split.keys()) != expected_keys:
730            raise ValueError(
731                f"The keys of class_split must be consecutive integers from 1 to num_tasks; got {sorted(self.class_split.keys())}."
732            )
733        if any(len(split) < 2 for split in self.class_split.values()):
734            raise ValueError("Each class split must contain at least 2 classes.")
735
736    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
737        r"""Get the mapping of classes of task `task_id` that fits the continual learning paradigm `self.cl_paradigm`.
738
739        **Args:**
740        - **task_id** (`int`): the task ID to query the CL class map.
741
742        **Returns:**
743        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
744            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task are consecutive integers starting from 0.
745            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task are consecutive integers starting from the total number of classes in all previous tasks.
746        """
747        num_classes_t = len(
748            self.class_split[task_id]
749        )  # the number of classes in the current task, i.e. the length of the class split
750        class_map_t = (
751            self.original_dataset_constants.CLASS_MAP
752        )  # same as the original dataset
753
754        if self.cl_paradigm == "TIL":
755            return {
756                class_map_t[self.class_split[task_id][i]]: i
757                for i in range(num_classes_t)
758            }
759        if self.cl_paradigm == "CIL":
760            num_classes_previous = sum(
761                len(self.class_split[i]) for i in range(1, task_id)
762            )
763            return {
764                class_map_t[self.class_split[task_id][i]]: num_classes_previous + i
765                for i in range(num_classes_t)
766            }
767
768    def setup_task_id(self, task_id: int) -> None:
769        r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.
770
771        **Args:**
772        - **task_id** (`int`): the target task ID.
773        """
774        super().setup_task_id(task_id)
775
776        self.mean_t = (
777            self.original_dataset_constants.MEAN
778        )  # same as the original dataset
779        self.std_t = (
780            self.original_dataset_constants.STD
781        )  # same as the original dataset
782
783    @abstractmethod
784    def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
785        r"""Get a subset of classes from the dataset for the current task `self.task_id`. It is used when constructing the split. **It must be implemented by subclasses.**
786
787        **Args:**
788        - **dataset** (`Dataset`): the dataset to retrieve the subset from.
789
790        **Returns:**
791        - **subset** (`Dataset`): the subset of classes from the dataset.
792        """

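In the CIL branch above, the offset is the running sum of split sizes over previous tasks. A toy check with a hypothetical `class_split`:

```python
class_split = {1: [0, 5], 2: [1, 2, 6], 3: [3, 4]}  # hypothetical split


def cil_offset(task_id: int) -> int:
    # Mirrors num_classes_previous in get_cl_class_map above.
    return sum(len(class_split[i]) for i in range(1, task_id))


print(cil_offset(1), cil_offset(2), cil_offset(3))  # 0 2 5
# So task 3's original classes 3 and 4 map to CL labels 5 and 6.
```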
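As a hint of what a subclass might do in `get_subset_of_classes`, here is a hedged sketch for original datasets that expose a `targets` attribute (true for many torchvision datasets such as MNIST and CIFAR, but not guaranteed in general; actual implementations vary per dataset):

```python
from torch.utils.data import Dataset, Subset


def get_subset_of_classes_sketch(dataset: Dataset, classes: list[int]) -> Dataset:
    """Keep only the samples whose label is in `classes` (illustrative only)."""
    wanted = set(classes)
    # Assumes `dataset.targets` holds per-sample labels.
    indices = [i for i, y in enumerate(dataset.targets) if int(y) in wanted]
    return Subset(dataset, indices)
```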
class CLCombinedDataset(clarena.cl_datasets.CLDataset):
795class CLCombinedDataset(CLDataset):
796    r"""The base class of continual learning datasets constructed as combinations of several single-task datasets (one dataset per task)."""
797
798    def __init__(
799        self,
800        datasets: dict[int, str],
801        root: str | dict[int, str],
802        batch_size: int | dict[int, int] = 1,
803        num_workers: int | dict[int, int] = 0,
804        custom_transforms: (
805            Callable
806            | transforms.Compose
807            | None
808            | dict[int, Callable | transforms.Compose | None]
809        ) = None,
810        repeat_channels: int | None | dict[int, int | None] = None,
811        to_tensor: bool | dict[int, bool] = True,
812        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
813    ) -> None:
814        r"""
815        **Args:**
816        - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
817        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
818        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, it is the same root directory for all tasks.
819        Note: `num_tasks` is not a constructor argument here; it is derived as `len(datasets)`, and valid task IDs run from 1 to `num_tasks`.
820        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
821        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
822        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
823        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
824        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. Basic transforms like `ToTensor()` and normalization are handled separately and must not be included here.
825        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
826        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
827        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
828        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
829        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
830        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
831        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
832        """
833        super().__init__(
834            root=root,
835            num_tasks=len(
836                datasets
837            ),  # num_tasks is not explicitly provided, but derived from the datasets length
838            batch_size=batch_size,
839            num_workers=num_workers,
840            custom_transforms=custom_transforms,
841            repeat_channels=repeat_channels,
842            to_tensor=to_tensor,
843            resize=resize,
844        )
845
846        self.original_dataset_python_classes: dict[int, type[Dataset]] = {
847            t: str_to_class(dataset_class_path)
848            for t, dataset_class_path in datasets.items()
849        }
850        r"""The dict of dataset classes for each task."""
851        self.original_dataset_python_class_t: type[Dataset]
852        r"""The dataset class for the current task `self.task_id`."""
853        self.original_dataset_constants_t: type[DatasetConstants]
854        r"""The original dataset constants class for the current task `self.task_id`."""
855
856    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
857        r"""Get the mapping of classes of task `task_id` that fits the continual learning paradigm `self.cl_paradigm`.
858
859        **Args:**
860        - **task_id** (`int`): the task ID to query the CL class map.
861
862        **Returns:**
863        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
864            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task are consecutive integers starting from 0.
865            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task are consecutive integers starting from the total number of classes in all previous tasks.
866        """
867        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]
868        original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
869            original_dataset_python_class_t
870        ]
871        num_classes_t = original_dataset_constants_t.NUM_CLASSES
872        class_map_t = original_dataset_constants_t.CLASS_MAP
873
874        if self.cl_paradigm == "TIL":
875            return {class_map_t[i]: i for i in range(num_classes_t)}
876        if self.cl_paradigm == "CIL":
877            num_classes_previous = sum(
878                [
879                    DATASET_CONSTANTS_MAPPING[
880                        self.original_dataset_python_classes[i]
881                    ].NUM_CLASSES
882                    for i in range(1, task_id)
883                ]
884            )
885            return {
886                class_map_t[i]: num_classes_previous + i for i in range(num_classes_t)
887            }
888
889    def setup_task_id(self, task_id: int) -> None:
890        r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.
891
892        **Args:**
893        - **task_id** (`int`): the target task ID.
894        """
895
896        self.original_dataset_python_class_t = self.original_dataset_python_classes[
897            task_id
898        ]
899
900        self.original_dataset_constants_t: type[DatasetConstants] = (
901            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class_t]
902        )
903
904        super().setup_task_id(task_id)
905
906        self.mean_t = self.original_dataset_constants_t.MEAN
907        self.std_t = self.original_dataset_constants_t.STD
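`str_to_class` resolves the dotted class-path strings in `datasets` to actual dataset classes. A functionally similar sketch (CLArena's own utility may differ in details such as error handling):

```python
import importlib


def str_to_class_sketch(class_path: str) -> type:
    """Resolve a dotted class path, e.g. 'torchvision.datasets.MNIST'."""
    module_path, _, class_name = class_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


mnist_cls = str_to_class_sketch("torchvision.datasets.MNIST")
print(mnist_cls.__name__)  # MNIST
```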

The base class of continual learning datasets constructed as combinations of several single-task datasets (one dataset per task).

CLCombinedDataset( datasets: dict[int, str], root: str | dict[int, str], batch_size: int | dict[int, int] = 1, num_workers: int | dict[int, int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)
    def __init__(
        self,
        datasets: dict[int, str],
        root: str | dict[int, str],
        batch_size: int | dict[int, int] = 1,
        num_workers: int | dict[int, int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    ) -> None:
        r"""
        **Args:**
        - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task. The number of tasks is derived from the length of this dict.
        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the CL dataset physically live.
        If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, the same root directory is used for all tasks.
        - **batch_size** (`int` | `dict[int, int]`): the batch size for the train, val, and test dataloaders.
        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, the same batch size is used for all tasks.
        - **num_workers** (`int` | `dict[int, int]`): the number of workers for the dataloaders.
        If it is a dict, the keys are task IDs and the values are the numbers of workers for each task. If it is an `int`, the same number of workers is used for all tasks.
        - **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permutation, and so on should not be included here, as they are handled separately.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the numbers of channels to repeat for each task. If it is an `int`, the same number of repeated channels is used for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it applies to all tasks.
        - **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize.
        If it is a dict, the keys are task IDs and the values are the resize targets for each task. If it is a single tuple of two integers, it applies to all tasks. If it is `None`, no resize is applied.
        """
        super().__init__(
            root=root,
            num_tasks=len(
                datasets
            ),  # num_tasks is not an explicit argument; it is derived from the length of `datasets`
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.original_dataset_python_classes: dict[int, type[Dataset]] = {
            t: str_to_class(dataset_class_path)
            for t, dataset_class_path in datasets.items()
        }
        r"""The dict of dataset classes for each task."""
        self.original_dataset_python_class_t: type[Dataset]
        r"""The dataset class for the current task `self.task_id`."""
        self.original_dataset_constants_t: type[DatasetConstants]
        r"""The original dataset constants class for the current task `self.task_id`."""

Args:

  • datasets (dict[int, str]): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task. The number of tasks is derived from the length of this dict.
  • root (str | dict[int, str]): the root directory where the original data files for constructing the CL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directories for each task. If it is a string, the same root directory is used for all tasks.
  • batch_size (int | dict[int, int]): the batch size for the train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an int, the same batch size is used for all tasks.
  • num_workers (int | dict[int, int]): the number of workers for the dataloaders. If it is a dict, the keys are task IDs and the values are the numbers of workers for each task. If it is an int, the same number of workers is used for all tasks.
  • custom_transforms (transform | transforms.Compose | None | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, permutation, and so on should not be included here, as they are handled separately. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the numbers of channels to repeat for each task. If it is an int, the same number of repeated channels is used for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it applies to all tasks.
  • resize (tuple[int, int] | None | dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the resize targets for each task. If it is a single tuple of two integers, it applies to all tasks. If it is None, no resize is applied.
original_dataset_python_classes: dict[int, type[torch.utils.data.dataset.Dataset]]

The dict of dataset classes for each task.

original_dataset_python_class_t: type[torch.utils.data.dataset.Dataset]

The dataset class for the current task self.task_id.

original_dataset_constants_t: type[clarena.stl_datasets.raw.constants.DatasetConstants]

The original dataset constants class for the current task self.task_id.
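
For illustration, here is a minimal construction sketch; the dataset class paths and root directories are hypothetical placeholders, not defaults of this API. It shows the scalar-vs-dict convention used throughout the arguments above: a scalar applies to all tasks, while a dict is keyed by task ID.

    # Hypothetical two-task combination: MNIST followed by CIFAR-10.
    combined = CLCombinedDataset(
        datasets={
            1: "torchvision.datasets.MNIST",
            2: "torchvision.datasets.CIFAR10",
        },
        root={1: "data/mnist", 2: "data/cifar10"},  # per-task roots (dict form)
        batch_size=64,                              # shared by all tasks (scalar form)
        num_workers=4,
        repeat_channels={1: 3, 2: None},            # repeat MNIST's single channel to 3
        resize=(32, 32),                            # num_tasks is derived: len(datasets) == 2
    )

Deriving the number of tasks from len(datasets) keeps the task count and the task-to-dataset assignment in sync by construction.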

def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
    def get_cl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit the continual learning paradigm `self.cl_paradigm`.

        **Args:**
        - **task_id** (`int`): the task ID to query the CL class map.

        **Returns:**
        - **cl_class_map** (`dict[str | int, int]`): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task are consecutive integers starting from 0.
            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task are consecutive integers starting from the total number of classes of all previous tasks.
        """
        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]
        original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
            original_dataset_python_class_t
        ]
        num_classes_t = original_dataset_constants_t.NUM_CLASSES
        class_map_t = original_dataset_constants_t.CLASS_MAP

        if self.cl_paradigm == "TIL":
            return {class_map_t[i]: i for i in range(num_classes_t)}
        if self.cl_paradigm == "CIL":
            num_classes_previous = sum(
                DATASET_CONSTANTS_MAPPING[
                    self.original_dataset_python_classes[i]
                ].NUM_CLASSES
                for i in range(1, task_id)
            )
            return {
                class_map_t[i]: num_classes_previous + i for i in range(num_classes_t)
            }
        raise ValueError(
            f"Unknown CL paradigm: {self.cl_paradigm}. Expected 'TIL' or 'CIL'."
        )  # avoid silently returning None for an unrecognised paradigm

Get the mapping of classes of task task_id to fit the continual learning paradigm self.cl_paradigm.

Args:

  • task_id (int): the task ID to query the CL class map.

Returns:

  • cl_class_map (dict[str | int, int]): the CL class map of the task. Keys are the original class labels and values are the integer class labels for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task are consecutive integers starting from 0.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task are consecutive integers starting from the total number of classes of all previous tasks.
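
To make the CIL offset concrete, here is a small self-contained sketch of the accumulation logic; the per-task class counts are hypothetical:

    # Hypothetical class counts per task ID.
    num_classes = {1: 10, 2: 10, 3: 100}

    def cil_offset(task_id: int) -> int:
        # Total number of classes of all previous tasks, mirroring the
        # sum over range(1, task_id) in get_cl_class_map above.
        return sum(num_classes[t] for t in range(1, task_id))

    assert cil_offset(1) == 0    # task 1 labels map to 0..9
    assert cil_offset(2) == 10   # task 2 labels map to 10..19
    assert cil_offset(3) == 20   # task 3 labels map to 20..119

Under TIL the offset is always 0: every task's labels start from 0, and the task ID itself disambiguates between tasks.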
def setup_task_id(self, task_id: int) -> None:
    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before the `setup()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """

        self.original_dataset_python_class_t = self.original_dataset_python_classes[
            task_id
        ]

        self.original_dataset_constants_t: type[DatasetConstants] = (
            DATASET_CONSTANTS_MAPPING[self.original_dataset_python_class_t]
        )

        super().setup_task_id(task_id)

        self.mean_t = self.original_dataset_constants_t.MEAN
        self.std_t = self.original_dataset_constants_t.STD

Set up which task's dataset the CL experiment is on. This must be done before the setup() method is called.

Args:

  • task_id (int): the target task ID.
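
A minimal sketch of the expected call order, assuming a constructed instance cl_dataset; the loop, the Lightning setup(stage) call, and the assumption that the num_tasks value passed to the base class is exposed as an attribute are all illustrative, not prescribed by this class:

    # Hypothetical task loop: select the task before building its datasets.
    for task_id in range(1, cl_dataset.num_tasks + 1):
        cl_dataset.setup_task_id(task_id)  # select the task first...
        cl_dataset.setup("fit")            # ...then let Lightning build the datasets
        # The per-task normalization statistics are now available:
        mean, std = cl_dataset.mean_t, cl_dataset.std_t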