clarena.cl_datasets

Continual Learning Datasets

This submodule provides the continual learning datasets that can be used in CLArena.

Please note that this is API documentation. Please refer to the main documentation pages for more information about the CL datasets and how to configure and implement them:

  • Configure CL Dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/cl-dataset
  • Implement Your CL Dataset Class: https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset
  • A Beginners' Guide to Continual Learning (CL Dataset): https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset

The datasets are implemented as subclasses of the CLDataset class, which is the base class for all continual learning datasets in CLArena.

 1r"""
 2
 3# Continual Learning Datasets
 4
 5This submodule provides the **continual learning datasets** that can be used in CLArena. 
 6
 7Please note that this is API documentation. Please refer to the main documentation pages for more information about the CL datasets and how to configure and implement them:
 8
 9- [**Configure CL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/cl-dataset)
10- [**Implement Your CL Dataset Class**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset)
11- [**A Beginners' Guide to Continual Learning (CL Dataset)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-dataset)
12
13The datasets are implemented as subclasses of the `CLDataset` class, which is the base class for all continual learning datasets in CLArena.
14
15- `CLDataset`: The base class for continual learning datasets.
16- `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
17
18"""
19
20from .base import CLClassMapping, CLDataset, CLPermutedDataset, CLSplitDataset, Permute
21from .permuted_mnist import PermutedMNIST
22from .split_cifar100 import SplitCIFAR100
23
24__all__ = [
25    "CLDataset",
26    "CLPermutedDataset",
27    "CLSplitDataset",
28    "CLClassMapping",
29    "Permute",
30    "permuted_mnist",
31    "split_cifar100",
32]
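As a quick illustration, a minimal usage sketch is shown below. It assumes `PermutedMNIST` accepts the same constructor arguments as its `CLPermutedDataset` base class; the exact arguments may differ, so treat it as a sketch rather than a verbatim recipe.

```python
from clarena.cl_datasets import PermutedMNIST

# Minimal sketch, assuming PermutedMNIST mirrors the CLPermutedDataset constructor.
cl_dataset = PermutedMNIST(
    root="data/",               # where the original MNIST files live (or are downloaded to)
    num_tasks=10,               # number of permuted tasks
    validation_percentage=0.1,  # 10% of training data held out for validation
    batch_size=128,
)

cl_dataset.set_cl_paradigm("TIL")  # or "CIL"
cl_dataset.setup_task_id(1)        # select the first task (task IDs start from 1)
cl_dataset.prepare_data()
cl_dataset.setup("fit")
train_loader = cl_dataset.train_dataloader()
```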
class CLDataset(lightning.pytorch.core.datamodule.LightningDataModule):
 21class CLDataset(LightningDataModule):
 22    r"""The base class of continual learning datasets, inherited from `LightningDataModule`."""
 23
 24    def __init__(
 25        self,
 26        root: str,
 27        num_tasks: int,
 28        validation_percentage: float,
 29        batch_size: int = 1,
 30        num_workers: int = 10,
 31        custom_transforms: Callable | transforms.Compose | None = None,
 32        custom_target_transforms: Callable | transforms.Compose | None = None,
 33    ) -> None:
 34        r"""Initialise the CL dataset object providing the root where data files live.
 35
 36        **Args:**
 37        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
 38        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
 39        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
 40        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 41        - **num_workers** (`int`): the number of workers for dataloaders.
 42        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
 43        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
 44        """
 45        LightningDataModule.__init__(self)
 46
 47        self.root: str = root
 48        r"""Store the root directory of the original data files. Used when constructing the dataset."""
 49        self.num_tasks: int = num_tasks
 50        r"""Store the maximum number of tasks supported by the dataset."""
 51        self.validation_percentage: float = validation_percentage
 52        r"""Store the percentage to randomly split some of the training data into validation data."""
 53        self.batch_size: int = batch_size
 54        r"""Store the batch size. Used when constructing train, val, test dataloader."""
 55        self.num_workers: int = num_workers
 56        r"""Store the number of workers. Used when constructing train, val, test dataloader."""
 57        self.custom_transforms: Callable | transforms.Compose | None = custom_transforms
 58        r"""Store the custom transforms other than the basics. Used when constructing the dataset."""
 59        self.custom_target_transforms: Callable | transforms.Compose | None = (
 60            custom_target_transforms
 61        )
 62        r"""Store the custom target transforms other than the CL class mapping. Used when constructing the dataset."""
 63
 64        self.task_id: int
 65        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""
 66        self.cl_paradigm: str
 67        r"""Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from `set_cl_paradigm` and used to define the CL class map."""
 68
 69        self.cl_class_map_t: dict[str | int, int]
 70        r"""Store the CL class map for the current task `self.task_id`. """
 71        self.cl_class_mapping_t: Callable
 72        r"""Store the CL class mapping transform for the current task `self.task_id`. """
 73
 74        self.dataset_train: object
 75        r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
 76        self.dataset_val: object
 77        r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
 78        self.dataset_test: dict[str, object] = {}
 79        r"""The dictionary to store test dataset object. Keys are task IDs (string type) and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects."""
 80
 81        CLDataset.sanity_check(self)
 82
 83    def sanity_check(self) -> None:
 84        r"""Check the sanity of the arguments.
 85
 86        **Raises:**
 87        - **ValueError**: when the `validation_percentage` is not strictly between 0 and 1.
 88        """
 89        if not 0.0 < self.validation_percentage < 1.0:
 90            raise ValueError("The validation_percentage should be strictly between 0 and 1.")
 91
 92    @abstractmethod
 93    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
 94        r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.
 95
 96        **Args:**
 97        - **task_id** (`int`): The task ID to query CL class map.
 98
 99        **Returns:**
100        - **cl_class_map**(`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
101            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be consecutive integers starting from 0, i.e. 0 to the number of classes in the task minus 1.
102            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be consecutive integers starting from the total number of classes of all previous tasks.
103        """
104
105    @abstractmethod
106    def prepare_data(self) -> None:
107        r"""Use this to download and prepare data. It must be implemented by subclasses, regulated by `LightningDatamodule`."""
108
109    def setup(self, stage: str) -> None:
110        r"""Set up the dataset for different stages.
111
112        **Args:**
113        - **stage** (`str`): the stage of the experiment. Should be one of the following:
114            - 'fit' or 'validate': the training and validation datasets of the current task `self.task_id` should be assigned to `self.dataset_train` and `self.dataset_val`.
115            - 'test': the test dataset of the current task `self.task_id` should be added to `self.dataset_test`, which accumulates the test datasets of all seen tasks.
116        """
117        if stage == "fit" or stage == "validate":
118
119            pylogger.debug(
120                "Construct train and validation dataset for task %d...", self.task_id
121            )
122            self.dataset_train, self.dataset_val = self.train_and_val_dataset()
123            self.dataset_train.target_transform = (
124                self.target_transforms()
125            )  # apply target transform after potential class split
126            self.dataset_val.target_transform = (
127                self.target_transforms()
128            )  # apply target transform after potential class split
129            pylogger.debug(
130                "Train and validation dataset for task %d are ready.", self.task_id
131            )
132
133        if stage == "test":
134
135            pylogger.debug("Construct test dataset for task %d...", self.task_id)
136            self.dataset_test[f"{self.task_id}"] = self.test_dataset()
137            self.dataset_test[f"{self.task_id}"].target_transform = (
138                self.target_transforms()
139            )  # apply target transform after potential class split
140            pylogger.debug("Test dataset for task %d is ready.", self.task_id)
141
142    def setup_task_id(self, task_id: int) -> None:
143        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
144
145        **Args:**
146        - **task_id** (`int`): the target task ID.
147        """
148        self.task_id = task_id
149
150        self.cl_class_map_t = self.cl_class_map(task_id)
151        self.cl_class_mapping_t = CLClassMapping(self.cl_class_map_t)
152
153    def set_cl_paradigm(self, cl_paradigm: str) -> None:
154        r"""Set the continual learning paradigm to `self.cl_paradigm`. It is used to define the CL class map.
155
156        **Args:**
157        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
158        """
159        self.cl_paradigm = cl_paradigm
160
161    @abstractmethod
162    def mean(self, task_id: int) -> tuple[float]:
163        r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset.
164
165        **Returns:**
166        - **mean** (`tuple[float]`): the mean values for normalisation.
167        """
168
169    @abstractmethod
170    def std(self, task_id: int) -> tuple[float]:
171        r"""The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset.
172
173        **Returns:**
174        - **std** (`tuple[float]`): the standard deviation values for normalisation.
175        """
176
177    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
178        r"""Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset.
179
180        **Args:**
181        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
182
183        **Returns:**
184        - **train_and_val_transforms** (`transforms.Compose`): the composed training transforms.
185        """
186
187        return transforms.Compose(
188            list(
189                filter(
190                    None,
191                    [
192                        transforms.ToTensor() if to_tensor else None,
193                        self.custom_transforms,
194                        transforms.Normalize(
195                            self.mean(self.task_id), self.std(self.task_id)
196                        ),
197                    ],
198                )
199            )
200        )  # the order of transforms matters
201
202    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
203        r"""Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.
204
205        **Args:**
206        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
207
208        **Returns:**
209        - **test_transforms** (`transforms.Compose`): the composed test transforms.
210        """
211
212        return transforms.Compose(
213            list(
214                filter(
215                    None,
216                    [
217                        transforms.ToTensor() if to_tensor else None,
218                        transforms.Normalize(
219                            self.mean(self.task_id), self.std(self.task_id)
220                        ),
221                    ],
222                )
223            )
224        )  # the order of transforms matters
225
226    def target_transforms(self) -> transforms.Compose:
227        r"""The target transform generator for the dataset. It is a handy tool to use in subclasses when constructing the dataset.
228
229        It composes the custom target transforms with the CL class mapping, in that order.
230
231        **Returns:**
232        - **target_transforms** (`transforms.Compose`): the composed target transforms.
233
234        """
235
236        return transforms.Compose(
237            list(
238                filter(
239                    None,
240                    [
241                        self.custom_target_transforms,
242                        self.cl_class_mapping_t,
243                    ],
244                )
245            )
246        )  # the order of transforms matters
247
248    @abstractmethod
249    def train_and_val_dataset(self) -> Any:
250        r"""Get the training and validation dataset of task `self.task_id`. It must be implemented by subclasses.
251
252        **Returns:**
253        - **train_and_val_dataset** (`Any`): the train and validation dataset of task `self.task_id`.
254        """
255
256    @abstractmethod
257    def test_dataset(self) -> Any:
258        """Get the test dataset of task `self.task_id`. It must be implemented by subclasses.
259
260        **Returns:**
261        - **test_dataset** (`Any`): the test dataset of task `self.task_id`.
262        """
263
264    def train_dataloader(self) -> DataLoader:
265        r"""DataLoader generator for stage train of task `self.task_id`. It is automatically called before training.
266
267        **Returns:**
268        - **train_dataloader** (`Dataloader`): the train DataLoader of task `self.task_id`.
269        """
270
271        pylogger.debug("Construct train dataloader for task %d...", self.task_id)
272
273        return DataLoader(
274            dataset=self.dataset_train,
275            batch_size=self.batch_size,
276            shuffle=True,  # shuffle train batch to prevent overfitting
277            num_workers=self.num_workers,
278            persistent_workers=True,
279        )
280
281    def val_dataloader(self) -> DataLoader:
282        r"""DataLoader generator for stage validate. It is automatically called before validation.
283
284        **Returns:**
285        - **val_dataloader** (`Dataloader`): the validation DataLoader of task `self.task_id`.
286        """
287
288        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)
289
290        return DataLoader(
291            dataset=self.dataset_val,
292            batch_size=self.batch_size,
293            shuffle=False,  # don't have to shuffle val or test batch
294            num_workers=self.num_workers,
295            persistent_workers=True,
296        )
297
298    def test_dataloader(self) -> dict[str, DataLoader]:
299        r"""DataLoader generator for stage test. It is automatically called before testing.
300
301        **Returns:**
302        - **test_dataloader** (`dict[str, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs (string type) and values are the DataLoaders.
303        """
304
305        pylogger.debug("Construct test dataloader for task %d...", self.task_id)
306
307        return {
308            task_id: DataLoader(
309                dataset=dataset_test,
310                batch_size=self.batch_size,
311                shuffle=False,  # don't have to shuffle val or test batch
312                num_workers=self.num_workers,
313                persistent_workers=True,  # speed up the dataloader worker initialization
314            )
315            for task_id, dataset_test in self.dataset_test.items()
316        }

The base class of continual learning datasets, inherited from LightningDataModule.

CLDataset( root: str, num_tasks: int, validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None)
24    def __init__(
25        self,
26        root: str,
27        num_tasks: int,
28        validation_percentage: float,
29        batch_size: int = 1,
30        num_workers: int = 10,
31        custom_transforms: Callable | transforms.Compose | None = None,
32        custom_target_transforms: Callable | transforms.Compose | None = None,
33    ) -> None:
34        r"""Initialise the CL dataset object providing the root where data files live.
35
36        **Args:**
37        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
38        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
39        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
40        - **batch_size** (`int`): The batch size in train, val, test dataloader.
41        - **num_workers** (`int`): the number of workers for dataloaders.
42        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
43        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
44        """
45        LightningDataModule.__init__(self)
46
47        self.root: str = root
48        r"""Store the root directory of the original data files. Used when constructing the dataset."""
49        self.num_tasks: int = num_tasks
50        r"""Store the maximum number of tasks supported by the dataset."""
51        self.validation_percentage: float = validation_percentage
52        r"""Store the percentage to randomly split some of the training data into validation data."""
53        self.batch_size: int = batch_size
54        r"""Store the batch size. Used when constructing train, val, test dataloader."""
55        self.num_workers: int = num_workers
56        r"""Store the number of workers. Used when constructing train, val, test dataloader."""
57        self.custom_transforms: Callable | transforms.Compose | None = custom_transforms
58        r"""Store the custom transforms other than the basics. Used when constructing the dataset."""
59        self.custom_target_transforms: Callable | transforms.Compose | None = (
60            custom_target_transforms
61        )
62        r"""Store the custom target transforms other than the CL class mapping. Used when constructing the dataset."""
63
64        self.task_id: int
65        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""
66        self.cl_paradigm: str
67        r"""Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from `set_cl_paradigm` and used to define the CL class map."""
68
69        self.cl_class_map_t: dict[str | int, int]
70        r"""Store the CL class map for the current task `self.task_id`. """
71        self.cl_class_mapping_t: Callable
72        r"""Store the CL class mapping transform for the current task `self.task_id`. """
73
74        self.dataset_train: object
75        r"""The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
76        self.dataset_val: object
77        r"""The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
78        self.dataset_test: dict[str, object] = {}
79        r"""The dictionary to store test dataset object. Keys are task IDs (string type) and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects."""
80
81        CLDataset.sanity_check(self)

Initialise the CL dataset object providing the root where data files live.

Args:

  • root (str): the root directory where the original data files for constructing the CL dataset physically live.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset.
  • validation_percentage (float): the percentage to randomly split some of the training data into validation data.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalise, permute and so on are not included.
  • custom_target_transforms (transform or transforms.Compose or None): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
root: str

Store the root directory of the original data files. Used when constructing the dataset.

num_tasks: int

Store the maximum number of tasks supported by the dataset.

validation_percentage: float

Store the percentage to randomly split some of the training data into validation data.

batch_size: int

Store the batch size. Used when constructing train, val, test dataloader.

num_workers: int

Store the number of workers. Used when constructing train, val, test dataloader.

custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType]

Store the custom transforms other than the basics. Used when constructing the dataset.

custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType]

Store the custom target transforms other than the CL class mapping. Used when constructing the dataset.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop.

cl_paradigm: str

Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from set_cl_paradigm and used to define the CL class map.

cl_class_map_t: dict[str | int, int]

Store the CL class map for the current task self.task_id.

cl_class_mapping_t: Callable

Store the CL class mapping transform for the current task self.task_id.

dataset_train: object

The training dataset object. Can be a PyTorch Dataset object or any other dataset object.

dataset_val: object

The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.

dataset_test: dict[str, object]

The dictionary to store test dataset object. Keys are task IDs (string type) and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

def sanity_check(self) -> None:
83    def sanity_check(self) -> None:
84        r"""Check the sanity of the arguments.
85
86        **Raises:**
87        - **ValueError**: when the `validation_percentage` is not strictly between 0 and 1.
88        """
89        if not 0.0 < self.validation_percentage < 1.0:
90            raise ValueError("The validation_percentage should be strictly between 0 and 1.")

Check the sanity of the arguments.

Raises:

  • ValueError: when the validation_percentage is not strictly between 0 and 1.
@abstractmethod
def cl_class_map(self, task_id: int) -> dict[str | int, int]:
 92    @abstractmethod
 93    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
 94        r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.
 95
 96        **Args:**
 97        - **task_id** (`int`): The task ID to query CL class map.
 98
 99        **Returns:**
100        - **cl_class_map**(`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
101            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be consecutive integers starting from 0, i.e. 0 to the number of classes in the task minus 1.
102            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be consecutive integers starting from the total number of classes of all previous tasks.
103        """

The mapping of classes of task task_id to fit continual learning settings self.cl_paradigm. It must be implemented by subclasses.

Args:

  • task_id (int): The task ID to query CL class map.

Returns:

  • cl_class_map(dict[str | int, int]): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task should be consecutive integers starting from 0, i.e. 0 to the number of classes in the task minus 1.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task should be consecutive integers starting from the total number of classes of all previous tasks.
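As a concrete (hypothetical) illustration, suppose the second task of a split dataset contains the original classes 'cat' and 'dog', and the first task also had two classes. The returned maps under the two paradigms would look like:

```python
# Hypothetical task 2 with original classes 'cat' and 'dog' (2 classes per task).

# TIL: labels restart from 0 within each task.
cl_class_map_til = {"cat": 0, "dog": 1}

# CIL: labels continue after the 2 classes already used by task 1.
cl_class_map_cil = {"cat": 2, "dog": 3}
```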
@abstractmethod
def prepare_data(self) -> None:
105    @abstractmethod
106    def prepare_data(self) -> None:
107        r"""Use this to download and prepare data. It must be implemented by subclasses, regulated by `LightningDatamodule`."""

Use this to download and prepare data. It must be implemented by subclasses, regulated by LightningDatamodule.

def setup(self, stage: str) -> None:
109    def setup(self, stage: str) -> None:
110        r"""Set up the dataset for different stages.
111
112        **Args:**
113        - **stage** (`str`): the stage of the experiment. Should be one of the following:
114            - 'fit' or 'validate': the training and validation datasets of the current task `self.task_id` should be assigned to `self.dataset_train` and `self.dataset_val`.
115            - 'test': the test dataset of the current task `self.task_id` should be added to `self.dataset_test`, which accumulates the test datasets of all seen tasks.
116        """
117        if stage == "fit" or stage == "validate":
118
119            pylogger.debug(
120                "Construct train and validation dataset for task %d...", self.task_id
121            )
122            self.dataset_train, self.dataset_val = self.train_and_val_dataset()
123            self.dataset_train.target_transform = (
124                self.target_transforms()
125            )  # apply target transform after potential class split
126            self.dataset_val.target_transform = (
127                self.target_transforms()
128            )  # apply target transform after potential class split
129            pylogger.debug(
130                "Train and validation dataset for task %d are ready.", self.task_id
131            )
132
133        if stage == "test":
134
135            pylogger.debug("Construct test dataset for task %d...", self.task_id)
136            self.dataset_test[f"{self.task_id}"] = self.test_dataset()
137            self.dataset_test[f"{self.task_id}"].target_transform = (
138                self.target_transforms()
139            )  # apply target transform after potential class split
140            pylogger.debug("Test dataset for task %d is ready.", self.task_id)

Set up the dataset for different stages.

Args:

  • stage (str): the stage of the experiment. Should be one of the following:
    • 'fit' or 'validate': the training and validation datasets of the current task self.task_id should be assigned to self.dataset_train and self.dataset_val.
    • 'test': the test dataset of the current task self.task_id should be added to self.dataset_test, which accumulates the test datasets of all seen tasks.
def setup_task_id(self, task_id: int) -> None:
142    def setup_task_id(self, task_id: int) -> None:
143        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
144
145        **Args:**
146        - **task_id** (`int`): the target task ID.
147        """
148        self.task_id = task_id
149
150        self.cl_class_map_t = self.cl_class_map(task_id)
151        self.cl_class_mapping_t = CLClassMapping(self.cl_class_map_t)

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.
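A sketch of the typical per-task call order follows. The surrounding loop is illustrative (the actual task loop is driven by the CLArena experiment code), but it shows that setup_task_id() must precede setup() for each task, given a CLDataset instance cl_dataset.

```python
# Illustrative task loop over a CLDataset instance `cl_dataset`.
cl_dataset.set_cl_paradigm("TIL")
for task_id in range(1, cl_dataset.num_tasks + 1):
    cl_dataset.setup_task_id(task_id)  # sets task_id and the CL class mapping for this task
    cl_dataset.setup("fit")            # builds dataset_train / dataset_val for this task
    # ... train and validate on this task ...
    cl_dataset.setup("test")           # adds this task's test set to dataset_test
    # ... evaluate on all seen tasks via test_dataloader() ...
```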
def set_cl_paradigm(self, cl_paradigm: str) -> None:
153    def set_cl_paradigm(self, cl_paradigm: str) -> None:
154        r"""Set the continual learning paradigm to `self.cl_paradigm`. It is used to define the CL class map.
155
156        **Args:**
157        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
158        """
159        self.cl_paradigm = cl_paradigm

Set the continual learning paradigm to self.cl_paradigm. It is used to define the CL class map.

Args:

  • cl_paradigm (str): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
@abstractmethod
def mean(self, task_id: int) -> tuple[float]:
161    @abstractmethod
162    def mean(self, task_id: int) -> tuple[float]:
163        r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset.
164
165        **Returns:**
166        - **mean** (`tuple[float]`): the mean values for normalisation.
167        """

The mean values for normalisation of task task_id. Used when constructing the dataset.

Returns:

  • mean (tuple[float]): the mean values for normalisation.
@abstractmethod
def std(self, task_id: int) -> tuple[float]:
169    @abstractmethod
170    def std(self, task_id: int) -> tuple[float]:
171        r"""The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset.
172
173        **Returns:**
174        - **std** (`tuple[float]`): the standard deviation values for normalisation.
175        """

The standard deviation values for normalisation of task task_id. Used when constructing the dataset.

Returns:

  • std (tuple[float]): the standard deviation values for normalisation.
def train_and_val_transforms(self, to_tensor: bool) -> torchvision.transforms.transforms.Compose:
177    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
178        r"""Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset.
179
180        **Args:**
181        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
182
183        **Returns:**
184        - **train_and_val_transforms** (`transforms.Compose`): the composed training transforms.
185        """
186
187        return transforms.Compose(
188            list(
189                filter(
190                    None,
191                    [
192                        transforms.ToTensor() if to_tensor else None,
193                        self.custom_transforms,
194                        transforms.Normalize(
195                            self.mean(self.task_id), self.std(self.task_id)
196                        ),
197                    ],
198                )
199            )
200        )  # the order of transforms matters

Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like normalisation and ToTensor(). It is a handy tool to use in subclasses when constructing the dataset.

Args:

  • to_tensor (bool): whether to include ToTensor() transform.

Returns:

  • train_and_val_transforms (transforms.Compose): the composed training transforms.
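For example, a subclass might call this helper when constructing its datasets. The sketch below uses torchvision's MNIST and random_split purely for illustration; it is not the actual PermutedMNIST implementation.

```python
from torch.utils.data import random_split
from torchvision.datasets import MNIST

def train_and_val_dataset(self):
    # Illustrative subclass implementation of the abstract method; not the real code.
    dataset_train_and_val = MNIST(
        root=self.root,
        train=True,
        transform=self.train_and_val_transforms(to_tensor=True),
        download=True,
    )
    num_val = int(len(dataset_train_and_val) * self.validation_percentage)
    num_train = len(dataset_train_and_val) - num_val
    return random_split(dataset_train_and_val, [num_train, num_val])
```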
def test_transforms(self, to_tensor: bool) -> torchvision.transforms.transforms.Compose:
202    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
203        r"""Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.
204
205        **Args:**
206        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
207
208        **Returns:**
209        - **test_transforms** (`transforms.Compose`): the composed test transforms.
210        """
211
212        return transforms.Compose(
213            list(
214                filter(
215                    None,
216                    [
217                        transforms.ToTensor() if to_tensor else None,
218                        transforms.Normalize(
219                            self.mean(self.task_id), self.std(self.task_id)
220                        ),
221                    ],
222                )
223            )
224        )  # the order of transforms matters

Transforms generator for test dataset. Only basic transforms like normalisation and ToTensor() are included. It is a handy tool to use in subclasses when constructing the dataset.

Args:

  • to_tensor (bool): whether to include ToTensor() transform.

Returns:

  • test_transforms (transforms.Compose): the composed test transforms.
def target_transforms(self) -> torchvision.transforms.transforms.Compose:
226    def target_transforms(self) -> transforms.Compose:
227        r"""The target transform generator for the dataset. It is a handy tool to use in subclasses when constructing the dataset.
228
229        It composes the custom target transforms with the CL class mapping, in that order.
230
231        **Returns:**
232        - **target_transforms** (`transforms.Compose`): the composed target transforms.
233
234        """
235
236        return transforms.Compose(
237            list(
238                filter(
239                    None,
240                    [
241                        self.custom_target_transforms,
242                        self.cl_class_mapping_t,
243                    ],
244                )
245            )
246        )  # the order of transforms matters

The target transform generator for the dataset. It is a handy tool to use in subclasses when constructing the dataset. It composes the custom target transforms with the CL class mapping, in that order.

Returns:

  • target_transforms (transforms.Compose): the composed target transforms.
@abstractmethod
def train_and_val_dataset(self) -> Any:
248    @abstractmethod
249    def train_and_val_dataset(self) -> Any:
250        r"""Get the training and validation dataset of task `self.task_id`. It must be implemented by subclasses.
251
252        **Returns:**
253        - **train_and_val_dataset** (`Any`): the train and validation dataset of task `self.task_id`.
254        """

Get the training and validation dataset of task self.task_id. It must be implemented by subclasses.

Returns:

  • train_and_val_dataset (Any): the train and validation dataset of task self.task_id.
@abstractmethod
def test_dataset(self) -> Any:
256    @abstractmethod
257    def test_dataset(self) -> Any:
258        """Get the test dataset of task `self.task_id`. It must be implemented by subclasses.
259
260        **Returns:**
261        - **test_dataset** (`Any`): the test dataset of task `self.task_id`.
262        """

Get the test dataset of task self.task_id. It must be implemented by subclasses.

Returns:

  • test_dataset (Any): the test dataset of task self.task_id.
def train_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
264    def train_dataloader(self) -> DataLoader:
265        r"""DataLoader generator for stage train of task `self.task_id`. It is automatically called before training.
266
267        **Returns:**
268        - **train_dataloader** (`Dataloader`): the train DataLoader of task `self.task_id`.
269        """
270
271        pylogger.debug("Construct train dataloader for task %d...", self.task_id)
272
273        return DataLoader(
274            dataset=self.dataset_train,
275            batch_size=self.batch_size,
276            shuffle=True,  # shuffle train batch to prevent overfitting
277            num_workers=self.num_workers,
278            persistent_workers=True,
279        )

DataLoader generator for stage train of task self.task_id. It is automatically called before training.

Returns:

  • train_dataloader (Dataloader): the train DataLoader of task self.task_id.
def val_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
281    def val_dataloader(self) -> DataLoader:
282        r"""DataLoader generator for stage validate. It is automatically called before validation.
283
284        **Returns:**
285        - **val_dataloader** (`Dataloader`): the validation DataLoader of task `self.task_id`.
286        """
287
288        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)
289
290        return DataLoader(
291            dataset=self.dataset_val,
292            batch_size=self.batch_size,
293            shuffle=False,  # don't have to shuffle val or test batch
294            num_workers=self.num_workers,
295            persistent_workers=True,
296        )

DataLoader generator for stage validate. It is automatically called before validation.

Returns:

  • val_dataloader (Dataloader): the validation DataLoader of task self.task_id.
def test_dataloader(self) -> dict[str, torch.utils.data.dataloader.DataLoader]:
298    def test_dataloader(self) -> dict[str, DataLoader]:
299        r"""DataLoader generator for stage test. It is automatically called before testing.
300
301        **Returns:**
302        - **test_dataloader** (`dict[str, DataLoader]`): the test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs (string type) and values are the DataLoaders.
303        """
304
305        pylogger.debug("Construct test dataloader for task %d...", self.task_id)
306
307        return {
308            task_id: DataLoader(
309                dataset=dataset_test,
310                batch_size=self.batch_size,
311                shuffle=False,  # don't have to shuffle val or test batch
312                num_workers=self.num_workers,
313                persistent_workers=True,  # speed up the dataloader worker initialization
314            )
315            for task_id, dataset_test in self.dataset_test.items()
316        }

DataLoader generator for stage test. It is automatically called before testing.

Returns:

  • test_dataloader (dict[str, DataLoader]): the test DataLoader dict of self.task_id and all tasks before (as the test is conducted on all seen tasks). Keys are task IDs (string type) and values are the DataLoaders.
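For instance, an evaluation loop over all seen tasks can simply iterate the returned dict, given a CLDataset instance cl_dataset that has already been set up (illustrative):

```python
# Illustrative: evaluate on every task seen so far.
for task_id, test_loader in cl_dataset.test_dataloader().items():
    num_batches = sum(1 for _ in test_loader)
    print(f"Task {task_id}: {num_batches} test batches")
```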
class CLPermutedDataset(clarena.cl_datasets.CLDataset):
319class CLPermutedDataset(CLDataset):
320    r"""The base class of continual learning datasets which are constructed as permutations from an original dataset, inherited from `CLDataset`."""
321
322    num_classes: int
323    r"""The number of classes in the original dataset before permutation. It must be provided in subclasses."""
324
325    img_size: torch.Size
326    r"""The size of images in the original dataset before permutation. Used when constructing permutation operations. It must be provided in subclasses."""
327
328    mean_original: tuple[float]
329    r"""The mean values for normalisation. It must be provided in subclasses."""
330
331    std_original: tuple[float]
332    r"""The standard deviation values for normalisation. It must be provided in subclasses."""
333
334    def __init__(
335        self,
336        root: str,
337        num_tasks: int,
338        validation_percentage: float,
339        batch_size: int = 1,
340        num_workers: int = 10,
341        custom_transforms: Callable | transforms.Compose | None = None,
342        custom_target_transforms: Callable | transforms.Compose | None = None,
343        permutation_mode: str = "first_channel_only",
344        permutation_seeds: list[int] | None = None,
345    ) -> None:
346        r"""Initialise the CL dataset object providing the root where data files live.
347
348        **Args:**
349        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
350        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
351        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
352        - **batch_size** (`int`): The batch size in train, val, test dataloader.
353        - **num_workers** (`int`): the number of workers for dataloaders.
354        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
355        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
356        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
357            1. 'all': permute all pixels.
358            2. 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
359            3. 'first_channel_only': permute only the first channel.
360        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is `None`, which creates a list of seeds from 0 to `num_tasks - 1`.
361        """
362        CLDataset.__init__(
363            self,
364            root=root,
365            num_tasks=num_tasks,
366            validation_percentage=validation_percentage,
367            batch_size=batch_size,
368            num_workers=num_workers,
369            custom_transforms=custom_transforms,
370            custom_target_transforms=custom_target_transforms,
371        )
372
373        self.permutation_mode: str = permutation_mode
374        r"""Store the mode of permutation. Used in the permutation operations that construct the tasks."""
375
376        self.permutation_seeds: list[int] = (
377            permutation_seeds if permutation_seeds else list(range(num_tasks))
378        )
379        r"""Store the permutation seeds for all tasks. Used in the permutation operations that construct the tasks."""
380
381        self.permutation_seed_t: int
382        r"""Store the permutation seed for the current task `self.task_id`."""
383        self.permute_t: Permute
384        r"""Store the permutation transform for the current task `self.task_id`. """
385
386        CLPermutedDataset.sanity_check(self)
387
388    def sanity_check(self) -> None:
389        r"""Check the sanity of the arguments.
390
391        **Raises:**
392        - **ValueError**: when the number of `permutation_seeds` is not equal to `num_tasks`, or the `permutation_mode` is not one of the valid options.
393        """
394        if self.permutation_seeds and self.num_tasks != len(self.permutation_seeds):
395            raise ValueError(
396                "The number of permutation seeds is not equal to number of tasks!"
397            )
398        if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]:
399            raise ValueError(
400                "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'."
401            )
402
403    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
404        r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
405
406        **Args:**
407        - **task_id** (`int`): The task ID to query CL class map.
408
409        **Returns:**
410        - **cl_class_map**(`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
411            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be consecutive integers starting from 0, i.e. 0 to the number of classes in the task minus 1.
412            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be consecutive integers starting from the total number of classes of all previous tasks.
413        """
414        if self.cl_paradigm == "TIL":
415            return {i: i for i in range(self.num_classes)}
416        if self.cl_paradigm == "CIL":
417            return {
418                i: i + (task_id - 1) * self.num_classes for i in range(self.num_classes)
419            }
420
421    def setup_task_id(self, task_id: int) -> None:
422        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
423
424        **Args:**
425        - **task_id** (`int`): the target task ID.
426        """
427        super().setup_task_id(task_id)
428
429        self.permutation_seed_t = self.permutation_seeds[task_id - 1]
430        self.permute_t = Permute(
431            img_size=self.img_size,
432            mode=self.permutation_mode,
433            seed=self.permutation_seed_t,
434        )
435
436    def mean(self, task_id: int) -> tuple[float]:
437        r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.
438
439        **Returns:**
440        - **mean** (`tuple[float]`): the mean values for normalisation.
441        """
442        return self.mean_original
443
444    def std(self, task_id: int) -> tuple[float]:
445        r"""The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL datasets, the standard deviation values are the same as the original dataset.
446
447        **Returns:**
448        - **std** (`tuple[float]`): the standard deviation values for normalisation.
449        """
450        return self.std_original
451
452    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
453        r"""Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. In permuted CL datasets, permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.
454
455        **Args:**
456        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
457
458        **Returns:**
459        - **train_and_val_transforms** (`transforms.Compose`): the composed training transforms.
460        """
461
462        return transforms.Compose(
463            list(
464                filter(
465                    None,
466                    [
467                        transforms.ToTensor() if to_tensor else None,
468                        self.permute_t,
469                        self.custom_transforms,
470                        transforms.Normalize(
471                            self.mean(self.task_id), self.std(self.task_id)
472                        ),
473                    ],
474                )
475            )
476        )  # the order of transforms matters
477
478    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
479        r"""Transforms generator for test dataset incorporating basic transforms like `normalisation` and `ToTensor()`. In permuted CL datasets, the permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.
480
481        **Args:**
482        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
483
484        **Returns:**
485        - **test_transforms** (`transforms.Compose`): the composed test transforms.
486        """
487
488        return transforms.Compose(
489            list(
490                filter(
491                    None,
492                    [
493                        transforms.ToTensor() if to_tensor else None,
494                        self.permute_t,
495                        transforms.Normalize(
496                            self.mean(self.task_id), self.std(self.task_id)
497                        ),
498                    ],
499                )
500            )
501        )  # the order of transforms matters

The base class of continual learning datasets which are constructed as permutations from an original dataset, inherited from CLDataset.

CLPermutedDataset( root: str, num_tasks: int, validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, permutation_mode: str = 'first_channel_only', permutation_seeds: list[int] | None = None)
334    def __init__(
335        self,
336        root: str,
337        num_tasks: int,
338        validation_percentage: float,
339        batch_size: int = 1,
340        num_workers: int = 10,
341        custom_transforms: Callable | transforms.Compose | None = None,
342        custom_target_transforms: Callable | transforms.Compose | None = None,
343        permutation_mode: str = "first_channel_only",
344        permutation_seeds: list[int] | None = None,
345    ) -> None:
346        r"""Initialise the CL dataset object providing the root where data files live.
347
348        **Args:**
349        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
350        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
351        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
352        - **batch_size** (`int`): The batch size in train, val, test dataloader.
353        - **num_workers** (`int`): the number of workers for dataloaders.
354        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
355        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
356        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
357            1. 'all': permute all pixels.
358            2. 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
359            3. 'first_channel_only': permute only the first channel.
360        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is `None`, which creates a list of seeds from 0 to `num_tasks - 1`.
361        """
362        CLDataset.__init__(
363            self,
364            root=root,
365            num_tasks=num_tasks,
366            validation_percentage=validation_percentage,
367            batch_size=batch_size,
368            num_workers=num_workers,
369            custom_transforms=custom_transforms,
370            custom_target_transforms=custom_target_transforms,
371        )
372
373        self.permutation_mode: str = permutation_mode
374        r"""Store the mode of permutation. Used in the permutation operations that construct the tasks."""
375
376        self.permutation_seeds: list[int] = (
377            permutation_seeds if permutation_seeds else list(range(num_tasks))
378        )
379        r"""Store the permutation seeds for all tasks. Used in the permutation operations that construct the tasks."""
380
381        self.permutation_seed_t: int
382        r"""Store the permutation seed for the current task `self.task_id`."""
383        self.permute_t: Permute
384        r"""Store the permutation transform for the current task `self.task_id`. """
385
386        CLPermutedDataset.sanity_check(self)

Initialise the CL dataset object providing the root where data files live.

Args:

  • root (str): the root directory where the original data files for constructing the CL dataset physically live.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset.
  • validation_percentage (float): the percentage to randomly split some of the training data into validation data.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalise, permute and so on are not included.
  • custom_target_transforms (transform or transforms.Compose or None): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
  • permutation_mode (str): the mode of permutation, should be one of the following:
    1. 'all': permute all pixels.
    2. 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
    3. 'first_channel_only': permute only the first channel.
  • permutation_seeds (list[int] or None): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as num_tasks. Default is None, which creates a list of seeds from 0 to num_tasks - 1.
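The actual Permute transform is defined in clarena.cl_datasets.base; the following is only a rough sketch of what a fixed pixel permutation shared across channels (in the spirit of the 'by_channel' mode) could look like, not the real implementation.

```python
import torch

# Rough sketch: a fixed pixel permutation shared across channels, seeded per task.
generator = torch.Generator().manual_seed(1)                 # e.g. the task's permutation seed
permutation = torch.randperm(28 * 28, generator=generator)   # assumes 28x28 images (MNIST-like)

def permute_pixels(img: torch.Tensor) -> torch.Tensor:
    """Apply the same fixed permutation to the pixels of every channel of a (C, H, W) image."""
    c, h, w = img.shape
    return img.reshape(c, h * w)[:, permutation].reshape(c, h, w)
```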
num_classes: int

The number of classes in the original dataset before permutation. It must be provided in subclasses.

img_size: torch.Size

The size of images in the original dataset before permutation. Used when constructing permutation operations. It must be provided in subclasses.

mean_original: tuple[float]

The mean values for normalisation. It must be provided in subclasses.

std_original: tuple[float]

The standard deviation values for normalisation. It must be provided in subclasses.

permutation_mode: str

Store the mode of permutation. Used when permutation operations used to construct tasks.

permutation_seeds: list[int]

Store the permutation seeds for all tasks. Use when permutation operations used to construct tasks.

permutation_seed_t: int

Store the permutation seed for the current task self.task_id.

permute_t: Permute

Store the permutation transform for the current task self.task_id.

def sanity_check(self) -> None:
388    def sanity_check(self) -> None:
389        r"""Check the sanity of the arguments.
390
391        **Raises:**
392        - **ValueError**: when the number of `permutation_seeds` is not equal to `num_tasks`, or the `permutation_mode` is not one of the valid options.
393        """
394        if self.permutation_seeds and self.num_tasks != len(self.permutation_seeds):
395            raise ValueError(
396                "The number of permutation seeds is not equal to number of tasks!"
397            )
398        if self.permutation_mode not in ["all", "by_channel", "first_channel_only"]:
399            raise ValueError(
400                "The permutation_mode should be one of 'all', 'by_channel', 'first_channel_only'."
401            )

Check the sanity of the arguments.

Raises:

  • ValueError: when the number of permutation_seeds is not equal to num_tasks, or the permutation_mode is not one of the valid options.
def cl_class_map(self, task_id: int) -> dict[str | int, int]:
403    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
404        r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
405
406        **Args:**
407        - **task_id** (`int`): The task ID to query CL class map.
408
409        **Returns:**
410        - **cl_class_map**(`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
411            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be consecutive integers starting from 0, i.e. 0 to the number of classes in the task minus 1.
412            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be consecutive integers starting from the total number of classes of all previous tasks.
413        """
414        if self.cl_paradigm == "TIL":
415            return {i: i for i in range(self.num_classes)}
416        if self.cl_paradigm == "CIL":
417            return {
418                i: i + (task_id - 1) * self.num_classes for i in range(self.num_classes)
419            }

The mapping of classes of task task_id to fit continual learning settings self.cl_paradigm.

Args:

  • task_id (int): The task ID to query CL class map.

Returns:

  • cl_class_map(dict[str | int, int]): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
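
As a worked example (assuming num_classes = 10), the two mappings above look like this; the snippet is a plain sketch, independent of the class:

num_classes = 10
task_id = 3

til_map = {i: i for i in range(num_classes)}
# {0: 0, 1: 1, ..., 9: 9} -- TIL: labels restart from 0 for every task

cil_map = {i: i + (task_id - 1) * num_classes for i in range(num_classes)}
# {0: 20, 1: 21, ..., 9: 29} -- CIL: labels continue after the 20 classes of tasks 1 and 2
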
def setup_task_id(self, task_id: int) -> None:
421    def setup_task_id(self, task_id: int) -> None:
422        r"""Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
423
424        **Args:**
425        - **task_id** (`int`): the target task ID.
426        """
427        super().setup_task_id(task_id)
428
429        self.permutation_seed_t = self.permutation_seeds[task_id - 1]
430        self.permute_t = Permute(
431            img_size=self.img_size,
432            mode=self.permutation_mode,
433            seed=self.permutation_seed_t,
434        )

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.
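
A minimal sketch of the expected call order (assuming the usual LightningDataModule workflow and that the dataloader methods are provided by CLDataset):

cl_dataset.setup_task_id(task_id=2)  # select task 2; builds the Permute transform for it
cl_dataset.setup(stage="fit")        # LightningDataModule setup() can now construct the task's datasets
train_loader = cl_dataset.train_dataloader()
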
def mean(self, task_id: int) -> tuple[float]:
436    def mean(self, task_id: int) -> tuple[float]:
437        r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.
438
439        **Returns:**
440        - **mean** (`tuple[float]`): the mean values for normalisation.
441        """
442        return self.mean_original

The mean values for normalisation of task task_id. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.

Returns:

  • mean (tuple[float]): the mean values for normalisation.
def std(self, task_id: int) -> tuple[float]:
444    def std(self, task_id: int) -> tuple[float]:
445        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the standard deviation values are the same as the original dataset.
446
447        **Returns:**
448        - **std** (`tuple[float]`): the standard deviation values for normalisation.
449        """
450        return self.std_original

The standard deviation values for normalisation of task task_id. Used when constructing the dataset. In permuted CL dataset, the standard deviation values are the same as the original dataset.

Returns:

  • std (tuple[float]): the standard deviation values for normalisation.
def train_and_val_transforms(self, to_tensor: bool) -> torchvision.transforms.transforms.Compose:
452    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
453        r"""Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. In permuted CL datasets, permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.
454
455        **Args:**
456        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
457
458        **Returns:**
459        - **train_and_val_transforms** (`transforms.Compose`): the composed training and validation transforms.
460        """
461
462        return transforms.Compose(
463            list(
464                filter(
465                    None,
466                    [
467                        transforms.ToTensor() if to_tensor else None,
468                        self.permute_t,
469                        self.custom_transforms,
470                        transforms.Normalize(
471                            self.mean(self.task_id), self.std(self.task_id)
472                        ),
473                    ],
474                )
475            )
476        )  # the order of transforms matters

Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like normalisation and ToTensor(). In permuted CL datasets, permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.

Args:

  • to_tensor (bool): whether to include ToTensor() transform.

Returns:

  • train_and_val_transforms (transforms.Compose): the composed training and validation transforms.
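
A minimal sketch of how a subclass's setup() might use this helper when constructing a torchvision dataset (the subclass, dataset choice, and attribute names are illustrative assumptions, not the actual implementation):

from torchvision.datasets import MNIST

from clarena.cl_datasets import CLPermutedDataset


class PermutedMNISTSketch(CLPermutedDataset):
    """Hypothetical subclass showing where the transform helper is used."""

    def setup(self, stage: str) -> None:
        if stage == "fit":
            # MNIST images are PIL images, so ToTensor() is requested;
            # the helper appends the permutation and normalisation transforms itself
            self.dataset_train_t = MNIST(
                root=self.root,
                train=True,
                download=True,
                transform=self.train_and_val_transforms(to_tensor=True),
            )
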
def test_transforms(self, to_tensor: bool) -> torchvision.transforms.transforms.Compose:
478    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
479        r"""Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.
480
481        **Args:**
482        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
483
484        **Returns:**
485        - **test_transforms** (`transforms.Compose`): the composed test transforms.
486        """
487
488        return transforms.Compose(
489            list(
490                filter(
491                    None,
492                    [
493                        transforms.ToTensor() if to_tensor else None,
494                        self.permute_t,
495                        transforms.Normalize(
496                            self.mean(self.task_id), self.std(self.task_id)
497                        ),
498                    ],
499                )
500            )
501        )  # the order of transforms matters

Transforms generator for test dataset. Only basic transforms like normalisation and ToTensor() are included. It is a handy tool to use in subclasses when constructing the dataset.

Args:

  • to_tensor (bool): whether to include ToTensor() transform.

Returns:

  • test_transforms (transforms.Compose): the composed test transforms.
class CLSplitDataset(clarena.cl_datasets.CLDataset):
504class CLSplitDataset(CLDataset):
505    r"""The base class of split continual learning datasets, which are constructed by splitting the classes of an original dataset, inherited from `CLDataset`."""
506
507    num_classes: int
508    r"""The number of classes in the original dataset before splitting. It must be provided in subclasses."""
509
510    mean_original: tuple[float]
511    r"""The mean values for normalisation. It must be provided in subclasses."""
512
513    std_original: tuple[float]
514    r"""The standard deviation values for normalisation. It must be provided in subclasses."""
515
516    def __init__(
517        self,
518        root: str,
519        num_tasks: int,
520        class_split: list[list[int]],
521        validation_percentage: float,
522        batch_size: int = 1,
523        num_workers: int = 10,
524        custom_transforms: Callable | transforms.Compose | None = None,
525        custom_target_transforms: Callable | transforms.Compose | None = None,
526    ) -> None:
527        r"""Initialise the CL dataset object providing the root where data files live.
528
529        **Args:**
530        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
531        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
532        - **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
533        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
534        - **batch_size** (`int`): The batch size in train, val, test dataloader.
535        - **num_workers** (`int`): the number of workers for dataloaders.
536        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
537        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
538        """
539        CLDataset.__init__(
540            self,
541            root=root,
542            num_tasks=num_tasks,
543            validation_percentage=validation_percentage,
544            batch_size=batch_size,
545            num_workers=num_workers,
546            custom_transforms=custom_transforms,
547            custom_target_transforms=custom_target_transforms,
548        )
549
550        self.class_split = class_split
551        r"""Store the class split for each task. Used when constructing the split dataset."""
552
553        CLSplitDataset.sanity_check(self)
554
555    def sanity_check(self) -> None:
556        r"""Check the sanity of the arguments.
557
558        **Raises:**
559        - **ValueError**: when the length of `class_split` is not equal to `num_tasks`.
560        - **ValueError**: when any of the lists in `class_split` has fewer than 2 elements. A classification task must have at least 2 classes.
561        """
562        if len(self.class_split) != self.num_tasks:
563            raise ValueError(
564                "The length of class split is not equal to number of tasks!"
565            )
566        if any(len(split) < 2 for split in self.class_split):
567            raise ValueError("Each class split must contain at least 2 elements!")
568
569    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
570        r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
571
572        **Args:**
573        - **task_id** (`int`): The task ID to query CL class map.
574
575        **Returns:**
576        - **cl_class_map**(`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
577            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
578            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
579        """
580        num_classes_t = len(self.class_split[task_id - 1])
581        if self.cl_paradigm == "TIL":
582            return {self.class_split[task_id - 1][i]: i for i in range(num_classes_t)}
583        if self.cl_paradigm == "CIL":
584            num_classes_previous = sum(
585                len(self.class_split[i]) for i in range(task_id - 1)
586            )
587            return {
588                self.class_split[task_id - 1][i]: i
589                + num_classes_previous
590                for i in range(num_classes_t)
591            }
592
593    def mean(self, task_id: int) -> tuple[float]:
594        r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the mean values are the same as the original dataset.
595
596        **Returns:**
597        - **mean** (`tuple[float]`): the mean values for normalisation.
598        """
599        return self.mean_original
600
601    def std(self, task_id: int) -> tuple[float]:
602        r"""The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the standard deviation values are the same as the original dataset.
603
604        **Returns:**
605        - **std** (`tuple[float]`): the standard deviation values for normalisation.
606        """
607        return self.std_original
608
609    def get_class_subset(self, dataset: Dataset) -> Dataset:
610        r"""A utility method to retrieve the subset of a PyTorch `Dataset` containing only the classes of the current task `self.task_id`. It could be useful when constructing the split CL dataset.
611
612        **Args:**
613        - **dataset** (`Dataset`): the original dataset to retrieve subset from.
614
615        **Returns:**
616        - **subset** (`Dataset`): subset of original dataset in classes.
617        """
618        classes = self.class_split[self.task_id - 1]
619
620        # get the indices of the dataset that belong to the classes
621        idx = [i for i, (_, target) in enumerate(dataset) if target in classes]
622
623        # subset the dataset by the indices, in-place operation
624        dataset.data = dataset.data[idx]
625        dataset.targets = [dataset.targets[i] for i in idx]
626
627        return dataset

The base class of split continual learning datasets, which are constructed by splitting the classes of an original dataset, inherited from CLDataset.

CLSplitDataset( root: str, num_tasks: int, class_split: list[list[int]], validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None)
516    def __init__(
517        self,
518        root: str,
519        num_tasks: int,
520        class_split: list[list[int]],
521        validation_percentage: float,
522        batch_size: int = 1,
523        num_workers: int = 10,
524        custom_transforms: Callable | transforms.Compose | None = None,
525        custom_target_transforms: Callable | transforms.Compose | None = None,
526    ) -> None:
527        r"""Initialise the CL dataset object providing the root where data files live.
528
529        **Args:**
530        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
531        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
532        - **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
533        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
534        - **batch_size** (`int`): The batch size in train, val, test dataloader.
535        - **num_workers** (`int`): the number of workers for dataloaders.
536        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
537        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
538        """
539        CLDataset.__init__(
540            self,
541            root=root,
542            num_tasks=num_tasks,
543            validation_percentage=validation_percentage,
544            batch_size=batch_size,
545            num_workers=num_workers,
546            custom_transforms=custom_transforms,
547            custom_target_transforms=custom_target_transforms,
548        )
549
550        self.class_split = class_split
551        r"""Store the class split for each task. Used when constructing the split dataset."""
552
553        CLSplitDataset.sanity_check(self)

Initialise the CL dataset object providing the root where data files live.

Args:

  • root (str): the root directory where the original data files for constructing the CL dataset physically live.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset.
  • class_split (list[list[int]]): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
  • validation_percentage (float): the percentage to randomly split some of the training data into validation data.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalise, permute and so on are not included.
  • custom_target_transforms (transform or transforms.Compose or None): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
num_classes: int

The number of classes in the original dataset before splitting. It must be provided in subclasses.

mean_original: tuple[float]

The mean values for normalisation. It must be provided in subclasses.

std_original: tuple[float]

The standard deviation values for normalisation. It must be provided in subclasses.

class_split

Store the class split for each task. Used when constructing the split dataset.
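
For illustration, a minimal sketch of how a split CL dataset subclass such as SplitCIFAR100 might be instantiated with a 10-task class split (the exact constructor signature of the subclass may differ; the argument names below mirror CLSplitDataset and are assumptions):

from clarena.cl_datasets import SplitCIFAR100

# hypothetical 10-task Split CIFAR-100 setup with 10 classes per task
cl_dataset = SplitCIFAR100(
    root="data/",
    num_tasks=10,
    class_split=[list(range(10 * t, 10 * (t + 1))) for t in range(10)],
    validation_percentage=0.1,
    batch_size=64,
    num_workers=4,
)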

def sanity_check(self) -> None:
555    def sanity_check(self) -> None:
556        r"""Check the sanity of the arguments.
557
558        **Raises:**
559        - **ValueError**: when the length of `class_split` is not equal to `num_tasks`.
560        - **ValueError**: when any of the lists in `class_split` has fewer than 2 elements. A classification task must have at least 2 classes.
561        """
562        if len(self.class_split) != self.num_tasks:
563            raise ValueError(
564                "The length of class split is not equal to number of tasks!"
565            )
566        if any(len(split) < 2 for split in self.class_split):
567            raise ValueError("Each class split must contain at least 2 elements!")

Check the sanity of the arguments.

Raises:

  • ValueError: when the length of class_split is not equal to num_tasks.
  • ValueError: when any of the lists in class_split has fewer than 2 elements. A classification task must have at least 2 classes.
def cl_class_map(self, task_id: int) -> dict[str | int, int]:
569    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
570        r"""The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
571
572        **Args:**
573        - **task_id** (`int`): The task ID to query CL class map.
574
575        **Returns:**
576        - **cl_class_map**(`dict[str | int, int]`): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
577            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
578            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
579        """
580        num_classes_t = len(self.class_split[task_id - 1])
581        if self.cl_paradigm == "TIL":
582            return {self.class_split[task_id - 1][i]: i for i in range(num_classes_t)}
583        if self.cl_paradigm == "CIL":
584            num_classes_previous = sum(
585                len(self.class_split[i]) for i in range(task_id - 1)
586            )
587            return {
588                self.class_split[task_id - 1][i]: i
589                + num_classes_previous
590                for i in range(num_classes_t)
591            }

The mapping of classes of task task_id to fit continual learning settings self.cl_paradigm.

Args:

  • task_id (int): The task ID to query CL class map.

Returns:

  • cl_class_map(dict[str | int, int]): the CL class map of the task. Key is original class label, value is integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
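
As a worked example of the documented mapping (computed here independently of the method, for illustration only):

class_split = [[0, 1], [2, 3], [4, 5]]
task_id = 3
classes_t = class_split[task_id - 1]  # [4, 5]

til_map = {c: i for i, c in enumerate(classes_t)}
# {4: 0, 5: 1} -- TIL: labels restart from 0 within the task

num_classes_previous = sum(len(split) for split in class_split[: task_id - 1])  # 4
cil_map = {c: i + num_classes_previous for i, c in enumerate(classes_t)}
# {4: 4, 5: 5} -- CIL: labels continue after the 4 classes of tasks 1 and 2
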
def mean(self, task_id: int) -> tuple[float]:
593    def mean(self, task_id: int) -> tuple[float]:
594        r"""The mean values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the mean values are the same as the original dataset.
595
596        **Returns:**
597        - **mean** (`tuple[float]`): the mean values for normalisation.
598        """
599        return self.mean_original

The mean values for normalisation of task task_id. Used when constructing the dataset. In split CL dataset, the mean values are the same as the original dataset.

Returns:

  • mean (tuple[float]): the mean values for normalisation.
def std(self, task_id: int) -> tuple[float]:
601    def std(self, task_id: int) -> tuple[float]:
602        r"""The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the standard deviation values are the same as the original dataset.
603
604        **Returns:**
605        - **std** (`tuple[float]`): the standard deviation values for normalisation.
606        """
607        return self.std_original

The standard deviation values for normalisation of task task_id. Used when constructing the dataset. In split CL dataset, the standard deviation values are the same as the original dataset.

Returns:

  • std (tuple[float]): the standard deviation values for normalisation.
def get_class_subset( self, dataset: torch.utils.data.dataset.Dataset) -> torch.utils.data.dataset.Dataset:
609    def get_class_subset(self, dataset: Dataset) -> Dataset:
610        r"""A utility method to retrieve the subset of a PyTorch `Dataset` containing only the classes of the current task `self.task_id`. It could be useful when constructing the split CL dataset.
611
612        **Args:**
613        - **dataset** (`Dataset`): the original dataset to retrieve subset from.
614
615        **Returns:**
616        - **subset** (`Dataset`): subset of original dataset in classes.
617        """
618        classes = self.class_split[self.task_id - 1]
619
620        # get the indices of the dataset that belong to the classes
621        idx = [i for i, (_, target) in enumerate(dataset) if target in classes]
622
623        # subset the dataset by the indices, in-place operation
624        dataset.data = dataset.data[idx]
625        dataset.targets = [dataset.targets[i] for i in idx]
626
627        return dataset

A utility method to retrieve the subset of a PyTorch Dataset containing only the classes of the current task self.task_id. It could be useful when constructing the split CL dataset.

Args:

  • dataset (Dataset): the original dataset to retrieve subset from.

Returns:

  • subset (Dataset): subset of original dataset in classes.
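
A minimal sketch of how a subclass's setup() might use this utility on a torchvision dataset, assuming setup_task_id() has already selected the current task (the subclass, dataset choice, and attribute names are illustrative assumptions):

from torchvision.datasets import CIFAR100

from clarena.cl_datasets import CLSplitDataset


class SplitCIFAR100Sketch(CLSplitDataset):
    """Hypothetical subclass showing where get_class_subset() is used."""

    def setup(self, stage: str) -> None:
        if stage == "fit":
            full_train = CIFAR100(root=self.root, train=True, download=True)
            # keep only the samples whose labels belong to the current task's class split
            self.dataset_train_t = self.get_class_subset(full_train)
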
class CLClassMapping:
630class CLClassMapping:
631    r"""CL Class mapping to dataset labels. Used as a PyTorch target Transform."""
632
633    def __init__(self, cl_class_map: dict[str | int, int]) -> None:
634        r"""Initialise the CL class mapping transform object from the CL class map of a task.
635
636        **Args:**
637        - **cl_class_map** (`dict[str | int, int]`): the CL class map for a task.
638        """
639        self.cl_class_map = cl_class_map
640
641    def __call__(self, target: torch.Tensor) -> torch.Tensor:
642        r"""The CL class mapping transform to dataset labels. It is defined as a callable object like a PyTorch Transform.
643
644        **Args:**
645        - **target** (`Tensor`): the target tensor.
646
647        **Returns:**
648        - **transformed_target** (`Tensor`): the transformed target tensor.
649        """
650
651        return self.cl_class_map[target]

CL Class mapping to dataset labels. Used as a PyTorch target Transform.

CLClassMapping(cl_class_map: dict[str | int, int])
633    def __init__(self, cl_class_map: dict[str | int, int]) -> None:
634        r"""Initialise the CL class mapping transform object from the CL class map of a task.
635
636        **Args:**
637        - **cl_class_map** (`dict[str | int, int]`): the CL class map for a task.
638        """
639        self.cl_class_map = cl_class_map

Initialise the CL class mapping transform object from the CL class map of a task.

Args:

  • cl_class_map (dict[str | int, int]): the CL class map for a task.
cl_class_map
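
For illustration, CLClassMapping can be used directly as a target transform; the class map below is a made-up TIL-style example:

from clarena.cl_datasets import CLClassMapping

# map original labels 4 and 5 to CL labels 0 and 1
target_transform = CLClassMapping({4: 0, 5: 1})
print(target_transform(4))  # 0 -- usable as a torchvision target_transform
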
class Permute:
654class Permute:
655    r"""Permutation operation applied to an image. Used to construct permuted CL datasets.
656
657    Used as a PyTorch Dataset Transform.
658    """
659
660    def __init__(
661        self,
662        img_size: torch.Size,
663        mode: str = "first_channel_only",
664        seed: int | None = None,
665    ) -> None:
666        r"""Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.
667
668        **Args:**
669        - **img_size** (`torch.Size`): the size of the image to be permuted.
670        - **mode** (`str`): the mode of permutation, should be one of the following:
671            - 'all': permute all pixels.
672            - 'by_channel': permute each channel separately. The same permutation order is applied to all channels.
673            - 'first_channel_only': permute only the first channel.
674        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
675        """
676        self.mode = mode
677        r"""Store the mode of permutation."""
678
679        # get generator for permutation
680        torch_generator = torch.Generator()
681        if seed:
682            torch_generator.manual_seed(seed)
683
684        # calculate the number of pixels from the image size
685        if self.mode == "all":
686            num_pixels = img_size[0] * img_size[1] * img_size[2]
687        elif self.mode in ("by_channel", "first_channel_only"):
688            num_pixels = img_size[1] * img_size[2]
689
690        self.permute: torch.Tensor = torch.randperm(
691            num_pixels, generator=torch_generator
692        )
693        r"""The permutation order, a `Tensor` permuted from [0, 1, ..., `num_pixels`-1] with the given seed. It is the core element of permutation operation."""
694
695    def __call__(self, img: torch.Tensor) -> torch.Tensor:
696        r"""The permutation operation to image is defined as a callable object like a PyTorch Transform.
697
698        **Args:**
699        - **img** (`Tensor`): image to be permuted. Must match the size of `img_size` in the initialisation.
700
701        **Returns:**
702        - **img_permuted** (`Tensor`): the permuted image.
703        """
704
705        if self.mode == "all":
706
707            img_flat = img.view(
708                -1
709            )  # flatten the whole image to 1d so that it can be applied 1d permuted order
710            img_flat_permuted = img_flat[self.permute]  # conduct permutation operation
711            img_permuted = img_flat_permuted.view(
712                img.size()
713            )  # return to the original image shape
714            return img_permuted
715
716        if self.mode == "by_channel":
717
718            permuted_channels = []
719            for i in range(img.size(0)):
720                # act on every channel
721                channel_flat = img[i].view(
722                    -1
723                )  # flatten the channel to 1d so that it can be applied 1d permuted order
724                channel_flat_permuted = channel_flat[
725                    self.permute
726                ]  # conduct permutation operation
727                channel_permuted = channel_flat_permuted.view(
728                    img[0].size()
729                )  # return to the original channel shape
730                permuted_channels.append(channel_permuted)
731            img_permuted = torch.stack(
732                permuted_channels
733            )  # stack the permuted channels to restore the image
734            return img_permuted
735
736        if self.mode == "first_channel_only":
737
738            first_channel_flat = img[0].view(
739                -1
740            )  # flatten the first channel to 1d so that it can be applied 1d permuted order
741            first_channel_flat_permuted = first_channel_flat[
742                self.permute
743            ]  # conduct permutation operation
744            first_channel_permuted = first_channel_flat_permuted.view(
745                img[0].size()
746            )  # return to the original channel shape
747
748            img_permuted = img.clone()
749            img_permuted[0] = first_channel_permuted
750
751            return img_permuted

Permutation operation applied to an image. Used to construct permuted CL datasets.

Used as a PyTorch Dataset Transform.

Permute( img_size: torch.Size, mode: str = 'first_channel_only', seed: int | None = None)
660    def __init__(
661        self,
662        img_size: torch.Size,
663        mode: str = "first_channel_only",
664        seed: int | None = None,
665    ) -> None:
666        r"""Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.
667
668        **Args:**
669        - **img_size** (`torch.Size`): the size of the image to be permuted.
670        - **mode** (`str`): the mode of permutation, should be one of the following:
671            - 'all': permute all pixels.
672            - 'by_channel': permute each channel separately. The same permutation order is applied to all channels.
673            - 'first_channel_only': permute only the first channel.
674        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
675        """
676        self.mode = mode
677        r"""Store the mode of permutation."""
678
679        # get generator for permutation
680        torch_generator = torch.Generator()
681        if seed:
682            torch_generator.manual_seed(seed)
683
684        # calculate the number of pixels from the image size
685        if self.mode == "all":
686            num_pixels = img_size[0] * img_size[1] * img_size[2]
687        elif self.mode in ("by_channel", "first_channel_only"):
688            num_pixels = img_size[1] * img_size[2]
689
690        self.permute: torch.Tensor = torch.randperm(
691            num_pixels, generator=torch_generator
692        )
693        r"""The permutation order, a `Tensor` permuted from [0, 1, ..., `num_pixels`-1] with the given seed. It is the core element of permutation operation."""

Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.

Args:

  • img_size (torch.Size): the size of the image to be permuted.
  • mode (str): the mode of permutation, should be one of the following:
    • 'all': permute all pixels.
    • 'by_channel': permute each channel separately. The same permutation order is applied to all channels.
    • 'first_channel_only': permute only the first channel.
  • seed (int or None): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
mode

Store the mode of permutation.

permute: torch.Tensor

The permutation order, a Tensor permuted from [0, 1, ..., num_pixels-1] with the given seed. It is the core element of permutation operation.
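
For illustration, a minimal usage sketch of Permute as a standalone transform (the image size and seed are arbitrary):

import torch

from clarena.cl_datasets import Permute

# permute all pixels of a 1x28x28 image; a fixed seed always yields the same order
permute = Permute(img_size=torch.Size([1, 28, 28]), mode="all", seed=1)
img = torch.rand(1, 28, 28)
img_permuted = permute(img)
assert img_permuted.shape == img.shape
assert torch.equal(permute(img), img_permuted)  # deterministic for a fixed seed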