clarena.cl_datasets

Continual Learning Datasets

This submodule provides the continual learning datasets that can be used in CLArena.

Please note that this is the API documentation. Please refer to the main documentation pages for more information about the CL datasets and how to use and customize them:

  • Configure your CL dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/cl-dataset
  • Implement your CL dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset
  • A beginners' guide to continual learning (CL dataset): https://pengxiang-wang.com/posts/continual-learning-beginners-guide#CL-dataset

The datasets are implemented as subclasses of the CLDataset class, which is the base class for all continual learning datasets in CLArena.

 1"""
 2
 3# Continual Learning Datasets
 4
 5This submodule provides the **continual learning datasets** that can be used in CLArena. 
 6
 7Please note that this is the API documentation. Please refer to the main documentation page for more information about the CL datasets and how to use and customize them:
 8
 9- **Configure your CL dataset:** [https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/cl-dataset](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/cl-dataset)
10- **Implement your CL dataset:** [https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-dataset)
11- **A beginners' guide to continual learning (CL dataset):** [https://pengxiang-wang.com/posts/continual-learning-beginners-guide#CL-dataset](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#CL-dataset)
12
13The datasets are implemented as subclasses of the `CLDataset` class, which is the base class for all continual learning datasets in CLArena.
14
15- `CLDataset`: The base class for continual learning datasets.
16- `CLPermutedDataset`: The base class for permuted continual learning datasets. A child class of `CLDataset`.
17
18"""
19
20from .base import CLClassMapping, CLDataset, CLPermutedDataset, CLSplitDataset, Permute
21from .permuted_mnist import PermutedMNIST
22from .split_cifar100 import SplitCIFAR100
23
24__all__ = [
25    "CLDataset",
26    "CLPermutedDataset",
27    "CLSplitDataset",
28    "CLClassMapping",
29    "Permute",
30    "permuted_mnist",
31    "split_cifar100",
32]
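For orientation, here is a minimal usage sketch of the module's exports. It assumes that `PermutedMNIST` accepts the `CLPermutedDataset` constructor arguments documented below and that the data root exists; the values are placeholders, not taken from the package's own examples.

```python
from clarena.cl_datasets import PermutedMNIST

# Hypothetical construction; argument names follow the CLPermutedDataset
# signature documented below, values are placeholders.
cl_dataset = PermutedMNIST(
    root="data/",                # where the original MNIST files live (assumed path)
    num_tasks=10,                # number of permuted tasks
    validation_percentage=0.1,   # 10% of training data held out for validation
    batch_size=128,
    num_workers=4,
)
cl_dataset.set_cl_paradigm("TIL")  # or "CIL"
cl_dataset.setup_task_id(1)        # task IDs are 1-indexed
```
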
class CLDataset(lightning.pytorch.core.datamodule.LightningDataModule):
 22class CLDataset(LightningDataModule):
 23    """The base class of continual learning datasets, inherited from `LightningDataModule`."""
 24
 25    def __init__(
 26        self,
 27        root: str,
 28        num_tasks: int,
 29        validation_percentage: float,
 30        batch_size: int = 1,
 31        num_workers: int = 10,
 32        custom_transforms: Callable | transforms.Compose | None = None,
 33        custom_target_transforms: Callable | transforms.Compose | None = None,
 34    ):
 35        """Initialise the CL dataset object providing the root where data files live.
 36
 37        **Args:**
 38        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
 39        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
 40        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
 41        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 42        - **num_workers** (`int`): the number of workers for dataloaders.
 43        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
 44        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
 45        """
 46        super().__init__()
 47
 48        self.root: str = root
 49        """Store the root directory of the original data files. Used when constructing the dataset."""
 50        self.num_tasks: int = num_tasks
 51        """Store the maximum number of tasks supported by the dataset."""
 52        self.validation_percentage: float = validation_percentage
 53        """Store the percentage to randomly split some of the training data into validation data."""
 54        self.batch_size: int = batch_size
 55        """Store the batch size. Used when constructing train, val, test dataloader."""
 56        self.num_workers: int = num_workers
 57        """Store the number of workers. Used when constructing train, val, test dataloader."""
 58        self.custom_transforms: Callable | transforms.Compose | None = custom_transforms
 59        """Store the custom transforms other than the basics. Used when constructing the dataset."""
 60        self.custom_target_transforms: Callable | transforms.Compose | None = (
 61            custom_target_transforms
 62        )
 63        """Store the custom target transforms other than the CL class mapping. Used when constructing the dataset."""
 64
 65        self.task_id: int
 66        """Task ID counter indicating which task is being processed. Self updated during the task loop."""
 67        self.cl_paradigm: str
 68        """Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from `set_cl_paradigm` and used to define the CL class map."""
 69
 70        self.cl_class_map_t: dict[str | int, int]
 71        """Store the CL class map for the current task `self.task_id`. """
 72        self.cl_class_mapping_t: Callable
 73        """Store the CL class mapping transform for the current task `self.task_id`. """
 74
 75        self.dataset_train: object
 76        """The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
 77        self.dataset_val: object
 78        """The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
 79        self.dataset_test: dict[int, object] = {}
 80        """The dictionary to store test dataset object. Key is task_id, value is the dataset object. Can be a PyTorch Dataset object or any other dataset object."""
 81
 82        self.sanity_check()
 83
 84    def sanity_check(self) -> None:
 85        """Check the sanity of the arguments.
 86
 87        **Raises:**
 88        - **ValueError**: when the `validation_percentage` is not in the range of 0-1.
 89        """
 90        if not 0.0 < self.validation_percentage < 1.0:
 91            raise ValueError("The validation_percentage should be 0-1.")
 92
 93    @abstractmethod
 94    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
 95        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.
 96
 97        **Args:**
 98        - **task_id** (`int`): The task ID to query CL class map.
 99
100        **Returns:**
101        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
102            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
103            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
104
105        """
106
107    @abstractmethod
108    def prepare_data(self) -> None:
109        """Use this to download and prepare data. It must be implemented by subclasses, regulated by `LightningDatamodule`."""
110
111    def setup(self, stage: str) -> None:
112        """Set up the dataset for different stages.
113
114        **Args:**
115        - **stage** (`str`): the stage of the experiment. Should be one of the following:
116            - 'fit' or 'validate': training and validation dataset of current task `self.task_id` should be assigned to `self.dataset_train` and `self.dataset_val`.
117            - 'test': the test dataset of current task `self.task_id` should be added to the `self.dataset_test` dict, which accumulates the test datasets of all seen tasks (from task 1 to `self.task_id`).
118        """
119        if stage == "fit" or stage == "validate":
120
121            pylogger.debug(
122                "Construct train and validation dataset for task %d...", self.task_id
123            )
124            self.dataset_train, self.dataset_val = self.train_and_val_dataset()
125            pylogger.debug(
126                "Train and validation dataset for task %d are ready.", self.task_id
127            )
128
129        if stage == "test":
130
131            pylogger.debug("Construct test dataset for task %d...", self.task_id)
132            self.dataset_test[self.task_id] = self.test_dataset()
133            pylogger.debug("Test dataset for task %d are ready.", self.task_id)
134
135    def setup_task_id(self, task_id: int) -> None:
136        """Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
137
138        **Args:**
139        - **task_id** (`int`): the target task ID.
140        """
141        self.task_id = task_id
142
143        self.cl_class_map_t = self.cl_class_map(task_id)
144        self.cl_class_mapping_t = CLClassMapping(self.cl_class_map_t)
145
146    def set_cl_paradigm(self, cl_paradigm: str) -> None:
147        """Set the continual learning paradigm to `self.cl_paradigm`. It is used to define the CL class map.
148
149        **Args:**
150        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
151        """
152        self.cl_paradigm = cl_paradigm
153
154    @abstractmethod
155    def mean(self, task_id: int) -> tuple[float]:
156        """The mean values for normalisation of task `task_id`. Used when constructing the dataset.
157
158        **Returns:**
159        - The mean values for normalisation.
160        """
161
162    @abstractmethod
163    def std(self, task_id: int) -> tuple[float]:
164        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset.
165
166        **Returns:**
167        - The standard deviation values for normalisation.
168        """
169
170    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
171        """Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset.
172
173        **Args:**
174        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
175
176        **Returns:**
177        - The composed training transforms.
178        """
179
180        return transforms.Compose(
181            list(
182                filter(
183                    None,
184                    [
185                        transforms.ToTensor() if to_tensor else None,
186                        self.custom_transforms,
187                        transforms.Normalize(
188                            self.mean(self.task_id), self.std(self.task_id)
189                        ),
190                    ],
191                )
192            )
193        )  # the order of transforms matters
194
195    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
196        """Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.
197
198        **Args:**
199        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
200
201        **Returns:**
202        - The composed test transforms.
203        """
204
205        return transforms.Compose(
206            list(
207                filter(
208                    None,
209                    [
210                        transforms.ToTensor() if to_tensor else None,
211                        transforms.Normalize(
212                            self.mean(self.task_id), self.std(self.task_id)
213                        ),
214                    ],
215                )
216            )
217        )  # the order of transforms matters
218
219    def target_transforms(self) -> transforms.Compose:
220        """The target transform for the dataset. It is a handy tool to use in subclasses when constructing the dataset.
221
222        **Returns:**
223        - The composed target transforms, incorporating the custom target
224          transforms (if any) and the CL class mapping for the current task
225          `self.task_id`.
226
227        """
228
229        return transforms.Compose(
230            list(
231                filter(
232                    None,
233                    [
234                        self.custom_target_transforms,
235                        self.cl_class_mapping_t,
236                    ],
237                )
238            )
239        )  # the order of transforms matters
240
241    @abstractmethod
242    def train_and_val_dataset(self) -> object:
243        """Get the training and validation dataset of task `self.task_id`. It must be implemented by subclasses.
244
245        **Returns:**
246        - The train and validation dataset of task `self.task_id`.
247        """
248
249    @abstractmethod
250    def test_dataset(self) -> object:
251        """Get the test dataset of task `self.task_id`. It must be implemented by subclasses.
252
253        **Returns:**
254        - The test dataset of task `self.task_id`.
255        """
256
257    def train_dataloader(self) -> DataLoader:
258        """DataLoader generator for stage train of task `self.task_id`. It is automatically called before training.
259
260        **Returns:**
261        - The train DataLoader of task `self.task_id`.
262        """
263
264        pylogger.debug("Construct train dataloader for task %d...", self.task_id)
265
266        return DataLoader(
267            dataset=self.dataset_train,
268            batch_size=self.batch_size,
269            shuffle=True,  # shuffle train batch to prevent overfitting
270            num_workers=self.num_workers,
271            persistent_workers=True,
272        )
273
274    def val_dataloader(self) -> DataLoader:
275        """DataLoader generator for stage validate. It is automatically called before validating.
276
277        **Returns:**
278        - The validation DataLoader of task `self.task_id`.
279        """
280
281        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)
282
283        return DataLoader(
284            dataset=self.dataset_val,
285            batch_size=self.batch_size,
286            shuffle=False,  # don't have to shuffle val or test batch
287            num_workers=self.num_workers,
288            persistent_workers=True,
289        )
290
291    def test_dataloader(self) -> dict[int, DataLoader]:
292        """DataLoader generator for stage test. It is automatically called before testing.
293
294        **Returns:**
295        - The test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Key is task_id, value is the DataLoader.
296        """
297
298        pylogger.debug("Construct test dataloader for task %d...", self.task_id)
299
300        return {
301            task_id: DataLoader(
302                dataset=dataset_test,
303                batch_size=self.batch_size,
304                shuffle=False,  # don't have to shuffle val or test batch
305                num_workers=self.num_workers,
306                persistent_workers=True,  # speed up the dataloader worker initialization
307            )
308            for task_id, dataset_test in self.dataset_test.items()
309        }

The base class of continual learning datasets, inherited from LightningDataModule.

CLDataset( root: str, num_tasks: int, validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None)
25    def __init__(
26        self,
27        root: str,
28        num_tasks: int,
29        validation_percentage: float,
30        batch_size: int = 1,
31        num_workers: int = 10,
32        custom_transforms: Callable | transforms.Compose | None = None,
33        custom_target_transforms: Callable | transforms.Compose | None = None,
34    ):
35        """Initialise the CL dataset object providing the root where data files live.
36
37        **Args:**
38        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
39        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
40        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
41        - **batch_size** (`int`): The batch size in train, val, test dataloader.
42        - **num_workers** (`int`): the number of workers for dataloaders.
43        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
44        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
45        """
46        super().__init__()
47
48        self.root: str = root
49        """Store the root directory of the original data files. Used when constructing the dataset."""
50        self.num_tasks: int = num_tasks
51        """Store the maximum number of tasks supported by the dataset."""
52        self.validation_percentage: float = validation_percentage
53        """Store the percentage to randomly split some of the training data into validation data."""
54        self.batch_size: int = batch_size
55        """Store the batch size. Used when constructing train, val, test dataloader."""
56        self.num_workers: int = num_workers
57        """Store the number of workers. Used when constructing train, val, test dataloader."""
58        self.custom_transforms: Callable | transforms.Compose | None = custom_transforms
59        """Store the custom transforms other than the basics. Used when constructing the dataset."""
60        self.custom_target_transforms: Callable | transforms.Compose | None = (
61            custom_target_transforms
62        )
63        """Store the custom target transforms other than the CL class mapping. Used when constructing the dataset."""
64
65        self.task_id: int
66        """Task ID counter indicating which task is being processed. Self updated during the task loop."""
67        self.cl_paradigm: str
68        """Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from `set_cl_paradigm` and used to define the CL class map."""
69
70        self.cl_class_map_t: dict[str | int, int]
71        """Store the CL class map for the current task `self.task_id`. """
72        self.cl_class_mapping_t: Callable
73        """Store the CL class mapping transform for the current task `self.task_id`. """
74
75        self.dataset_train: object
76        """The training dataset object. Can be a PyTorch Dataset object or any other dataset object."""
77        self.dataset_val: object
78        """The validation dataset object. Can be a PyTorch Dataset object or any other dataset object."""
79        self.dataset_test: dict[int, object] = {}
80        """The dictionary to store test dataset object. Key is task_id, value is the dataset object. Can be a PyTorch Dataset object or any other dataset object."""
81
82        self.sanity_check()

Initialise the CL dataset object providing the root where data files live.

Args:

  • root (str): the root directory where the original data files for constructing the CL dataset physically live.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset.
  • validation_percentage (float): the percentage to randomly split some of the training data into validation data.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalise, permute and so on are not included.
  • custom_target_transforms (transform or transforms.Compose or None): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
root: str

Store the root directory of the original data files. Used when constructing the dataset.

num_tasks: int

Store the maximum number of tasks supported by the dataset.

validation_percentage: float

Store the percentage to randomly split some of the training data into validation data.

batch_size: int

Store the batch size. Used when constructing train, val, test dataloader.

num_workers: int

Store the number of workers. Used when constructing train, val, test dataloader.

custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType]

Store the custom transforms other than the basics. Used when constructing the dataset.

custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType]

Store the custom target transforms other than the CL class mapping. Used when constructing the dataset.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop.

cl_paradigm: str

Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Gotten from set_cl_paradigm and used to define the CL class map.

cl_class_map_t: dict[str | int, int]

Store the CL class map for the current task self.task_id.

cl_class_mapping_t: Callable

Store the CL class mapping transform for the current task self.task_id.

dataset_train: object

The training dataset object. Can be a PyTorch Dataset object or any other dataset object.

dataset_val: object

The validation dataset object. Can be a PyTorch Dataset object or any other dataset object.

dataset_test: dict[int, object]

The dictionary to store test dataset object. Key is task_id, value is the dataset object. Can be a PyTorch Dataset object or any other dataset object.

def sanity_check(self) -> None:
84    def sanity_check(self) -> None:
85        """Check the sanity of the arguments.
86
87        **Raises:**
88        - **ValueError**: when the `validation_percentage` is not in the range of 0-1.
89        """
90        if not 0.0 < self.validation_percentage < 1.0:
91            raise ValueError("The validation_percentage should be 0-1.")

Check the sanity of the arguments.

Raises:

  • ValueError: when the validation_percentage is not in the range of 0-1.
@abstractmethod
def cl_class_map(self, task_id: int) -> dict[str | int, int]:
 93    @abstractmethod
 94    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
 95        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`. It must be implemented by subclasses.
 96
 97        **Args:**
 98        - **task_id** (`int`): The task ID to query CL class map.
 99
100        **Returns:**
101        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
102            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
103            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
104
105        """

The mapping of classes of task task_id to fit continual learning settings self.cl_paradigm. It must be implemented by subclasses.

Args:

  • task_id (int): The task ID to query CL class map.

Returns:

  • The CL class map of the task. Key is original class label, value is integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
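As an illustration (not from the source), a hypothetical task with 10 classes would produce maps like the following under the two paradigms; the CIL offset follows the formula used by `CLPermutedDataset.cl_class_map` further below.

```python
# Hypothetical 10-class task with task_id = 3 (task IDs are 1-indexed).

# TIL: every task is mapped independently to labels 0..9.
til_map = {i: i for i in range(10)}                  # {0: 0, 1: 1, ..., 9: 9}

# CIL: labels continue after the classes of the previous tasks (tasks 1 and 2),
# so task 3 occupies labels 20..29.
cil_map = {i: i + (3 - 1) * 10 for i in range(10)}   # {0: 20, 1: 21, ..., 9: 29}
```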
@abstractmethod
def prepare_data(self) -> None:
107    @abstractmethod
108    def prepare_data(self) -> None:
109        """Use this to download and prepare data. It must be implemented by subclasses, regulated by `LightningDatamodule`."""

Use this to download and prepare data. It must be implemented by subclasses, as required by LightningDataModule.

def setup(self, stage: str) -> None:
111    def setup(self, stage: str) -> None:
112        """Set up the dataset for different stages.
113
114        **Args:**
115        - **stage** (`str`): the stage of the experiment. Should be one of the following:
116            - 'fit' or 'validate': training and validation dataset of current task `self.task_id` should be assigned to `self.dataset_train` and `self.dataset_val`.
117            - 'test': the test dataset of current task `self.task_id` should be added to the `self.dataset_test` dict, which accumulates the test datasets of all seen tasks (from task 1 to `self.task_id`).
118        """
119        if stage == "fit" or stage == "validate":
120
121            pylogger.debug(
122                "Construct train and validation dataset for task %d...", self.task_id
123            )
124            self.dataset_train, self.dataset_val = self.train_and_val_dataset()
125            pylogger.debug(
126                "Train and validation dataset for task %d are ready.", self.task_id
127            )
128
129        if stage == "test":
130
131            pylogger.debug("Construct test dataset for task %d...", self.task_id)
132            self.dataset_test[self.task_id] = self.test_dataset()
133            pylogger.debug("Test dataset for task %d are ready.", self.task_id)

Set up the dataset for different stages.

Args:

  • stage (str): the stage of the experiment. Should be one of the following:
    • 'fit' or 'validate': training and validation dataset of current task self.task_id should be assigned to self.dataset_train and self.dataset_val.
    • 'test': the test dataset of current task self.task_id should be added to the self.dataset_test dict, which accumulates the test datasets of all seen tasks (from task 1 to self.task_id).
def setup_task_id(self, task_id: int) -> None:
135    def setup_task_id(self, task_id: int) -> None:
136        """Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
137
138        **Args:**
139        - **task_id** (`int`): the target task ID.
140        """
141        self.task_id = task_id
142
143        self.cl_class_map_t = self.cl_class_map(task_id)
144        self.cl_class_mapping_t = CLClassMapping(self.cl_class_map_t)

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.
def set_cl_paradigm(self, cl_paradigm: str) -> None:
146    def set_cl_paradigm(self, cl_paradigm: str) -> None:
147        """Set the continual learning paradigm to `self.cl_paradigm`. It is used to define the CL class map.
148
149        **Args:**
150        - **cl_paradigm** (`str`): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
151        """
152        self.cl_paradigm = cl_paradigm

Set the continual learning paradigm to self.cl_paradigm. It is used to define the CL class map.

Args:

  • cl_paradigm (str): the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning).
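Putting `set_cl_paradigm()`, `setup_task_id()` and `setup()` together, a typical task loop looks roughly like the sketch below. In a real run the Lightning `Trainer` calls `setup()` and the dataloader hooks itself; this sketch only shows the required call order and is not taken from the package.

```python
# Sketch of the per-task call order (training/evaluation logic omitted).
cl_dataset.set_cl_paradigm("TIL")  # must be set before any class map is built

for task_id in range(1, cl_dataset.num_tasks + 1):
    cl_dataset.setup_task_id(task_id)   # must precede setup()
    cl_dataset.setup("fit")             # builds dataset_train / dataset_val for this task
    train_loader = cl_dataset.train_dataloader()
    val_loader = cl_dataset.val_dataloader()
    # ... train and validate the model on this task ...

    cl_dataset.setup("test")            # registers this task's test dataset
    test_loaders = cl_dataset.test_dataloader()   # dict keyed by all seen task IDs
    # ... evaluate the model on every seen task ...
```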
@abstractmethod
def mean(self, task_id: int) -> tuple[float]:
154    @abstractmethod
155    def mean(self, task_id: int) -> tuple[float]:
156        """The mean values for normalisation of task `task_id`. Used when constructing the dataset.
157
158        **Returns:**
159        - The mean values for normalisation.
160        """

The mean values for normalisation of task task_id. Used when constructing the dataset.

Returns:

  • The mean values for normalisation.
@abstractmethod
def std(self, task_id: int) -> tuple[float]:
162    @abstractmethod
163    def std(self, task_id: int) -> tuple[float]:
164        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset.
165
166        **Returns:**
167        - The standard deviation values for normalisation.
168        """

The standard deviation values for normalisation of task task_id. Used when constructing the dataset.

Returns:

  • The standard deviation values for normalisation.
def train_and_val_transforms(self, to_tensor: bool) -> torchvision.transforms.transforms.Compose:
170    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
171        """Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. It is a handy tool to use in subclasses when constructing the dataset.
172
173        **Args:**
174        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
175
176        **Returns:**
177        - The composed training transforms.
178        """
179
180        return transforms.Compose(
181            list(
182                filter(
183                    None,
184                    [
185                        transforms.ToTensor() if to_tensor else None,
186                        self.custom_transforms,
187                        transforms.Normalize(
188                            self.mean(self.task_id), self.std(self.task_id)
189                        ),
190                    ],
191                )
192            )
193        )  # the order of transforms matters

Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like normalisation and ToTensor(). It is a handy tool to use in subclasses when constructing the dataset.

Args:

  • to_tensor (bool): whether to include ToTensor() transform.

Returns:

  • The composed training transforms.
def test_transforms(self, to_tensor: bool) -> torchvision.transforms.transforms.Compose:
195    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
196        """Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.
197
198        **Args:**
199        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
200
201        **Returns:**
202        - The composed test transforms.
203        """
204
205        return transforms.Compose(
206            list(
207                filter(
208                    None,
209                    [
210                        transforms.ToTensor() if to_tensor else None,
211                        transforms.Normalize(
212                            self.mean(self.task_id), self.std(self.task_id)
213                        ),
214                    ],
215                )
216            )
217        )  # the order of transforms matters

Transforms generator for test dataset. Only basic transforms like normalisation and ToTensor() are included. It is a handy tool to use in subclasses when constructing the dataset.

Args:

  • to_tensor (bool): whether to include ToTensor() transform.

Returns:

  • The composed test transforms.
def target_transforms(self) -> torchvision.transforms.transforms.Compose:
219    def target_transforms(self) -> transforms.Compose:
220        """The target transform for the dataset. It is a handy tool to use in subclasses when constructing the dataset.
221
222        **Returns:**
223        - The composed target transforms, incorporating the custom target
224          transforms (if any) and the CL class mapping for the current task
225          `self.task_id`.
226
227        """
228
229        return transforms.Compose(
230            list(
231                filter(
232                    None,
233                    [
234                        self.custom_target_transforms,
235                        self.cl_class_mapping_t,
236                    ],
237                )
238            )
239        )  # the order of transforms matters

The target transform for the dataset. It is a handy tool to use in subclasses when constructing the dataset.

Returns:

  • The composed target transforms, incorporating the custom target transforms (if any) and the CL class mapping for the current task self.task_id.
@abstractmethod
def train_and_val_dataset(self) -> object:
241    @abstractmethod
242    def train_and_val_dataset(self) -> object:
243        """Get the training and validation dataset of task `self.task_id`. It must be implemented by subclasses.
244
245        **Returns:**
246        - The train and validation dataset of task `self.task_id`.
247        """

Get the training and validation dataset of task self.task_id. It must be implemented by subclasses.

Returns:

  • The train and validation dataset of task self.task_id.
@abstractmethod
def test_dataset(self) -> object:
249    @abstractmethod
250    def test_dataset(self) -> object:
251        """Get the test dataset of task `self.task_id`. It must be implemented by subclasses.
252
253        **Returns:**
254        - The test dataset of task `self.task_id`.
255        """

Get the test dataset of task self.task_id. It must be implemented by subclasses.

Returns:

  • The test dataset of task self.task_id.
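For reference, `setup()` above consumes these two hooks as in the following sketch (for a concrete subclass instance such as the `cl_dataset` sketched earlier):

```python
# train_and_val_dataset() must return a (train, val) pair, which
# setup("fit") unpacks into dataset_train and dataset_val;
# test_dataset() returns a single dataset that setup("test") stores
# in the dataset_test dict under the current task ID.
dataset_train, dataset_val = cl_dataset.train_and_val_dataset()
dataset_test_t = cl_dataset.test_dataset()
```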
def train_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
257    def train_dataloader(self) -> DataLoader:
258        """DataLoader generator for stage train of task `self.task_id`. It is automatically called before training.
259
260        **Returns:**
261        - The train DataLoader of task `self.task_id`.
262        """
263
264        pylogger.debug("Construct train dataloader for task %d...", self.task_id)
265
266        return DataLoader(
267            dataset=self.dataset_train,
268            batch_size=self.batch_size,
269            shuffle=True,  # shuffle train batch to prevent overfitting
270            num_workers=self.num_workers,
271            persistent_workers=True,
272        )

DataLoader generator for stage train of task self.task_id. It is automatically called before training.

Returns:

  • The train DataLoader of task self.task_id.
def val_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
274    def val_dataloader(self) -> DataLoader:
275        """DataLoader generator for stage validate. It is automatically called before validating.
276
277        **Returns:**
278        - The validation DataLoader of task `self.task_id`.
279        """
280
281        pylogger.debug("Construct validation dataloader for task %d...", self.task_id)
282
283        return DataLoader(
284            dataset=self.dataset_val,
285            batch_size=self.batch_size,
286            shuffle=False,  # don't have to shuffle val or test batch
287            num_workers=self.num_workers,
288            persistent_workers=True,
289        )

DataLoader generator for stage validate. It is automatically called before validating.

Returns:

  • The validation DataLoader of task self.task_id.
def test_dataloader(self) -> dict[int, torch.utils.data.dataloader.DataLoader]:
291    def test_dataloader(self) -> dict[int, DataLoader]:
292        """DataLoader generator for stage test. It is automatically called before testing.
293
294        **Returns:**
295        - The test DataLoader dict of `self.task_id` and all tasks before (as the test is conducted on all seen tasks). Key is task_id, value is the DataLoader.
296        """
297
298        pylogger.debug("Construct test dataloader for task %d...", self.task_id)
299
300        return {
301            task_id: DataLoader(
302                dataset=dataset_test,
303                batch_size=self.batch_size,
304                shuffle=False,  # don't have to shuffle val or test batch
305                num_workers=self.num_workers,
306                persistent_workers=True,  # speed up the dataloader worker initialization
307            )
308            for task_id, dataset_test in self.dataset_test.items()
309        }

DataLoader generator for stage test. It is automatically called before testing.

Returns:

  • The test DataLoader dict of self.task_id and all tasks before (as the test is conducted on all seen tasks). Key is task_id, value is the DataLoader.
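A hedged sketch of consuming this dict when evaluating after each task (the metric accumulation is left abstract):

```python
# Evaluate on every seen task after finishing the current one.
test_loaders = cl_dataset.test_dataloader()   # {task_id: DataLoader, ...}
for seen_task_id, test_loader in test_loaders.items():
    for x, y in test_loader:
        ...  # run the model on batch (x, y) and accumulate metrics per seen_task_id
```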
class CLPermutedDataset(clarena.cl_datasets.CLDataset):
312class CLPermutedDataset(CLDataset):
313    """The base class of continual learning datasets which are constructed as permutations from an original dataset, inherited from `CLDataset`."""
314
315    num_classes: int
316    """The number of classes in the original dataset before permutation. It must be provided in subclasses."""
317
318    img_size: torch.Size
319    """The size of images in the original dataset before permutation. Used when constructing permutation operations. It must be provided in subclasses."""
320
321    mean_original: tuple[float]
322    """The mean values for normalisation. It must be provided in subclasses."""
323
324    std_original: tuple[float]
325    """The standard deviation values for normalisation. It must be provided in subclasses."""
326
327    def __init__(
328        self,
329        root: str,
330        num_tasks: int,
331        validation_percentage: float,
332        batch_size: int = 1,
333        num_workers: int = 10,
334        custom_transforms: Callable | transforms.Compose | None = None,
335        custom_target_transforms: Callable | transforms.Compose | None = None,
336        permutation_mode: str = "first_channel_only",
337        permutation_seeds: list[int] | None = None,
338    ):
339        """Initialise the CL dataset object providing the root where data files live.
340
341        **Args:**
342        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
343        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
344        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
345        - **batch_size** (`int`): The batch size in train, val, test dataloader.
346        - **num_workers** (`int`): the number of workers for dataloaders.
347        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
348        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
349        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
350            1. 'all': permute all pixels.
351            2. 'by_channel': permute channel by channel separately; the same permutation order is applied to all channels.
352            3. 'first_channel_only': permute only the first channel.
353        - **permutation_seeds** (`list[int]` or `None`): the seeds for the permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is `None`, which creates a list of seeds from 0 to `num_tasks` - 1.
354        """
355        self.permutation_mode: str = permutation_mode
356        """Store the mode of permutation. Used when permutation operations used to construct tasks. """
357
358        self.permutation_seeds: list[int] = (
359            permutation_seeds if permutation_seeds else list(range(num_tasks))
360        )
361        """Store the permutation seeds for all tasks. Use when permutation operations used to construct tasks. """
362
363        self.permutation_seed_t: int
364        """Store the permutation seed for the current task `self.task_id`."""
365        self.permute_t: Permute
366        """Store the permutation transform for the current task `self.task_id`. """
367
368        super().__init__(
369            root=root,
370            num_tasks=num_tasks,
371            validation_percentage=validation_percentage,
372            batch_size=batch_size,
373            num_workers=num_workers,
374            custom_transforms=custom_transforms,
375            custom_target_transforms=custom_target_transforms,
376        )
377
378    def sanity_check(self) -> None:
379        """Check the sanity of the arguments.
380
381        **Raises:**
382        - **ValueError**: when the number of `permutation_seeds` is not equal to `num_tasks`.
383        """
384        if self.permutation_seeds and self.num_tasks != len(self.permutation_seeds):
385            raise ValueError(
386                "The number of permutation seeds is not equal to number of tasks!"
387            )
388
389        super().sanity_check()
390
391    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
392        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
393
394        **Args:**
395        - **task_id** (`int`): The task ID to query CL class map.
396
397        **Returns:**
398        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
399            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
400            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
401        """
402        if self.cl_paradigm == "TIL":
403            return {i: i for i in range(self.num_classes)}
404        if self.cl_paradigm == "CIL":
405            return {
406                i: i + (task_id - 1) * self.num_classes for i in range(self.num_classes)
407            }
408
409    def setup_task_id(self, task_id: int) -> None:
410        """Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
411
412        **Args:**
413        - **task_id** (`int`): the target task ID.
414        """
415        super().setup_task_id(task_id)
416
417        self.permutation_seed_t = self.permutation_seeds[task_id - 1]
418        self.permute_t = Permute(
419            img_size=self.img_size,
420            mode=self.permutation_mode,
421            seed=self.permutation_seed_t,
422        )
423
424    def mean(self, task_id: int) -> tuple[float]:
425        """The mean values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.
426
427        **Returns:**
428        - The mean values for normalisation.
429        """
430        return self.mean_original
431
432    def std(self, task_id: int) -> tuple[float]:
433        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.
434
435        **Returns:**
436        - The standard deviation values for normalisation.
437        """
438        return self.std_original
439
440    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
441        """Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. In permuted CL datasets, permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.
442
443        **Args:**
444        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
445
446        **Returns:**
447        - The composed training transforms.
448        """
449
450        return transforms.Compose(
451            list(
452                filter(
453                    None,
454                    [
455                        transforms.ToTensor() if to_tensor else None,
456                        self.permute_t,
457                        self.custom_transforms,
458                        transforms.Normalize(
459                            self.mean(self.task_id), self.std(self.task_id)
460                        ),
461                    ],
462                )
463            )
464        )  # the order of transforms matters
465
466    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
467        """Transforms generator for test dataset. Only basic transforms like `normalisation` and `ToTensor()` are included. It is a handy tool to use in subclasses when constructing the dataset.
468
469        **Args:**
470        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
471
472        **Returns:**
473        - The composed test transforms.
474        """
475
476        return transforms.Compose(
477            list(
478                filter(
479                    None,
480                    [
481                        transforms.ToTensor() if to_tensor else None,
482                        self.permute_t,
483                        transforms.Normalize(
484                            self.mean(self.task_id), self.std(self.task_id)
485                        ),
486                    ],
487                )
488            )
489        )  # the order of transforms matters

The base class of continual learning datasets which are constructed as permutations from an original dataset, inherited from CLDataset.

CLPermutedDataset( root: str, num_tasks: int, validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, permutation_mode: str = 'first_channel_only', permutation_seeds: list[int] | None = None)
327    def __init__(
328        self,
329        root: str,
330        num_tasks: int,
331        validation_percentage: float,
332        batch_size: int = 1,
333        num_workers: int = 10,
334        custom_transforms: Callable | transforms.Compose | None = None,
335        custom_target_transforms: Callable | transforms.Compose | None = None,
336        permutation_mode: str = "first_channel_only",
337        permutation_seeds: list[int] | None = None,
338    ):
339        """Initialise the CL dataset object providing the root where data files live.
340
341        **Args:**
342        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
343        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
344        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
345        - **batch_size** (`int`): The batch size in train, val, test dataloader.
346        - **num_workers** (`int`): the number of workers for dataloaders.
347        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
348        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
349        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
350            1. 'all': permute all pixels.
351            2. 'by_channel': permute channel by channel separately; the same permutation order is applied to all channels.
352            3. 'first_channel_only': permute only the first channel.
353        - **permutation_seeds** (`list[int]` or `None`): the seeds for the permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is `None`, which creates a list of seeds from 0 to `num_tasks` - 1.
354        """
355        self.permutation_mode: str = permutation_mode
356        """Store the mode of permutation. Used when permutation operations used to construct tasks. """
357
358        self.permutation_seeds: list[int] = (
359            permutation_seeds if permutation_seeds else list(range(num_tasks))
360        )
361        """Store the permutation seeds for all tasks. Use when permutation operations used to construct tasks. """
362
363        self.permutation_seed_t: int
364        """Store the permutation seed for the current task `self.task_id`."""
365        self.permute_t: Permute
366        """Store the permutation transform for the current task `self.task_id`. """
367
368        super().__init__(
369            root=root,
370            num_tasks=num_tasks,
371            validation_percentage=validation_percentage,
372            batch_size=batch_size,
373            num_workers=num_workers,
374            custom_transforms=custom_transforms,
375            custom_target_transforms=custom_target_transforms,
376        )

Initialise the CL dataset object providing the root where data files live.

Args:

  • root (str): the root directory where the original data files for constructing the CL dataset physically live.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset.
  • validation_percentage (float): the percentage to randomly split some of the training data into validation data.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalise, permute and so on are not included.
  • custom_target_transforms (transform or transforms.Compose or None): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
  • permutation_mode (str): the mode of permutation, should be one of the following:
    1. 'all': permute all pixels.
    2. 'by_channel': permute channel by channel separately; the same permutation order is applied to all channels.
    3. 'first_channel_only': permute only the first channel.
  • permutation_seeds (list[int] or None): the seeds for the permutation operations used to construct tasks. Make sure it has the same number of seeds as num_tasks. Default is None, which creates a list of seeds from 0 to num_tasks - 1.
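A hedged construction sketch for these permutation-specific arguments; the seed list length must equal num_tasks, otherwise sanity_check() below raises a ValueError. The class name and values are placeholders.

```python
# Hypothetical permuted CL dataset with explicit permutation settings.
cl_dataset = PermutedMNIST(
    root="data/",
    num_tasks=5,
    validation_percentage=0.1,
    permutation_mode="first_channel_only",   # or "all" / "by_channel"
    permutation_seeds=[11, 22, 33, 44, 55],  # one seed per task
)
```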
num_classes: int

The number of classes in the original dataset before permutation. It must be provided in subclasses.

img_size: torch.Size

The size of images in the original dataset before permutation. Used when constructing permutation operations. It must be provided in subclasses.

mean_original: tuple[float]

The mean values for normalisation. It must be provided in subclasses.

std_original: tuple[float]

The standard deviation values for normalisation. It must be provided in subclasses.

permutation_mode: str

Store the mode of permutation. Used for the permutation operations that construct tasks.

permutation_seeds: list[int]

Store the permutation seeds for all tasks. Used for the permutation operations that construct tasks.

permutation_seed_t: int

Store the permutation seed for the current task self.task_id.

permute_t: Permute

Store the permutation transform for the current task self.task_id.
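For reference, `setup_task_id()` below builds this transform roughly as in the following sketch (the keyword names are taken from the source; the image size value is a placeholder):

```python
import torch
from clarena.cl_datasets import Permute

# Per-task permutation transform, mirroring the call in setup_task_id().
permute_t = Permute(
    img_size=torch.Size([1, 28, 28]),   # hypothetical image size
    mode="first_channel_only",
    seed=1,                             # the task's permutation seed
)
```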

def sanity_check(self) -> None:
378    def sanity_check(self) -> None:
379        """Check the sanity of the arguments.
380
381        **Raises:**
382        - **ValueError**: when the number of `permutation_seeds` is not equal to `num_tasks`.
383        """
384        if self.permutation_seeds and self.num_tasks != len(self.permutation_seeds):
385            raise ValueError(
386                "The number of permutation seeds is not equal to number of tasks!"
387            )
388
389        super().sanity_check()

Check the sanity of the arguments.

Raises:

  • ValueError: when the number of permutation_seeds is not equal to num_tasks.
def cl_class_map(self, task_id: int) -> dict[str | int, int]:
391    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
392        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
393
394        **Args:**
395        - **task_id** (`int`): The task ID to query CL class map.
396
397        **Returns:**
398        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
399            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
400            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
401        """
402        if self.cl_paradigm == "TIL":
403            return {i: i for i in range(self.num_classes)}
404        if self.cl_paradigm == "CIL":
405            return {
406                i: i + (task_id - 1) * self.num_classes for i in range(self.num_classes)
407            }

The mapping of classes of task task_id to fit continual learning settings self.cl_paradigm.

Args:

  • task_id (int): The task ID to query CL class map.

Returns:

  • The CL class map of the task. Key is original class label, value is integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task should be continuous integers from the number of classes of previous tasks to the number of classes of the current task.
def setup_task_id(self, task_id: int) -> None:
409    def setup_task_id(self, task_id: int) -> None:
410        """Set up which task's dataset the CL experiment is on. This must be done before `setup()` method is called.
411
412        **Args:**
413        - **task_id** (`int`): the target task ID.
414        """
415        super().setup_task_id(task_id)
416
417        self.permutation_seed_t = self.permutation_seeds[task_id - 1]
418        self.permute_t = Permute(
419            img_size=self.img_size,
420            mode=self.permutation_mode,
421            seed=self.permutation_seed_t,
422        )

Set up which task's dataset the CL experiment is on. This must be done before setup() method is called.

Args:

  • task_id (int): the target task ID.
def mean(self, task_id: int) -> tuple[float]:
424    def mean(self, task_id: int) -> tuple[float]:
425        """The mean values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.
426
427        **Returns:**
428        - The mean values for normalisation.
429        """
430        return self.mean_original

The mean values for normalisation of task task_id. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.

Returns:

  • The mean values for normalisation.
def std(self, task_id: int) -> tuple[float]:
432    def std(self, task_id: int) -> tuple[float]:
433        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In permuted CL dataset, the mean values are the same as the original dataset.
434
435        **Returns:**
436        - The standard deviation values for normalisation.
437        """
438        return self.std_original

The standard deviation values for normalisation of task task_id. Used when constructing the dataset. In permuted CL datasets, the standard deviation values are the same as the original dataset.

Returns:

  • The standard deviation values for normalisation.
def train_and_val_transforms(self, to_tensor: bool) -> torchvision.transforms.transforms.Compose:
440    def train_and_val_transforms(self, to_tensor: bool) -> transforms.Compose:
441        """Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like `normalisation` and `ToTensor()`. In permuted CL datasets, permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.
442
443        **Args:**
444        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
445
446        **Returns:**
447        - The composed train/validation transforms.
448        """
449
450        return transforms.Compose(
451            list(
452                filter(
453                    None,
454                    [
455                        transforms.ToTensor() if to_tensor else None,
456                        self.permute_t,
457                        self.custom_transforms,
458                        transforms.Normalize(
459                            self.mean(self.task_id), self.std(self.task_id)
460                        ),
461                    ],
462                )
463            )
464        )  # the order of transforms matters

Transforms generator for train and validation dataset incorporating the custom transforms with basic transforms like normalisation and ToTensor(). In permuted CL datasets, permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.

Args:

  • to_tensor (bool): whether to include ToTensor() transform.

Returns:

  • The composed train/validation transforms.
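
A hedged sketch of how a subclass's `setup()` might consume this helper. `MNIST` stands in for any torchvision dataset; the subclass name, the `dataset_train_t` attribute, and the assumption that the base class stores `self.root` are for illustration only.

    from torchvision.datasets import MNIST
    from clarena.cl_datasets import CLPermutedDataset

    class MyPermutedMNIST(CLPermutedDataset):  # hypothetical subclass for illustration
        def setup(self, stage: str) -> None:
            if stage == "fit":
                # build the task's training set with the composed
                # ToTensor -> permute -> custom -> normalise pipeline
                self.dataset_train_t = MNIST(
                    root=self.root,  # assumed to be stored by the base class
                    train=True,
                    download=True,
                    transform=self.train_and_val_transforms(to_tensor=True),
                )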
def test_transforms(self, to_tensor: bool) -> torchvision.transforms.transforms.Compose:
466    def test_transforms(self, to_tensor: bool) -> transforms.Compose:
467        """Transforms generator for test dataset. Basic transforms like `normalisation` and `ToTensor()` are included; in permuted CL datasets, the permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.
468
469        **Args:**
470        - **to_tensor** (`bool`): whether to include `ToTensor()` transform.
471
472        **Returns:**
473        - The composed test transforms.
474        """
475
476        return transforms.Compose(
477            list(
478                filter(
479                    None,
480                    [
481                        transforms.ToTensor() if to_tensor else None,
482                        self.permute_t,
483                        transforms.Normalize(
484                            self.mean(self.task_id), self.std(self.task_id)
485                        ),
486                    ],
487                )
488            )
489        )  # the order of transforms matters

Transforms generator for test dataset. Basic transforms like normalisation and ToTensor() are included; in permuted CL datasets, the permute transform also applies. It is a handy tool to use in subclasses when constructing the dataset.

Args:

  • to_tensor (bool): whether to include ToTensor() transform.

Returns:

  • The composed test transforms.
class CLSplitDataset(clarena.cl_datasets.CLDataset):
492class CLSplitDataset(CLDataset):
493    """The base class of split continual learning datasets, which are constructed by splitting the classes of an original dataset into tasks, inherited from `CLDataset`."""
494
495    num_classes: int
496    """The number of classes in the original dataset. It must be provided in subclasses."""
497
498    mean_original: tuple[float]
499    """The mean values for normalisation. It must be provided in subclasses."""
500
501    std_original: tuple[float]
502    """The standard deviation values for normalisation. It must be provided in subclasses."""
503
504    def __init__(
505        self,
506        root: str,
507        num_tasks: int,
508        class_split: list[list[int]],
509        validation_percentage: float,
510        batch_size: int = 1,
511        num_workers: int = 10,
512        custom_transforms: Callable | transforms.Compose | None = None,
513        custom_target_transforms: Callable | transforms.Compose | None = None,
514    ):
515        """Initialise the CL dataset object providing the root where data files live.
516
517        **Args:**
518        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
519        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
520        - **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
521        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
522        - **batch_size** (`int`): The batch size in train, val, test dataloader.
523        - **num_workers** (`int`): the number of workers for dataloaders.
524        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
525        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
526        """
527
528        self.class_split = class_split
529        """Store the class split for each task. Used when constructing the split dataset."""
530
531        super().__init__(
532            root=root,
533            num_tasks=num_tasks,
534            validation_percentage=validation_percentage,
535            batch_size=batch_size,
536            num_workers=num_workers,
537            custom_transforms=custom_transforms,
538            custom_target_transforms=custom_target_transforms,
539        )
540
541    def sanity_check(self) -> None:
542        """Check the sanity of the arguments.
543
544        **Raises:**
545        - **ValueError**: when the length of `class_split` is not equal to `num_tasks`.
546        - **ValueError**: when any of the lists in `class_split` has fewer than 2 elements. A classification task must have at least 2 classes.
547        """
548        if len(self.class_split) != self.num_tasks:
549            raise ValueError(
550                "The length of class split is not equal to number of tasks!"
551            )
552        if any(len(split) < 2 for split in self.class_split):
553            raise ValueError("Each class split must contain at least 2 elements!")
554
555        super().sanity_check()
556
557    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
558        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
559
560        **Args:**
561        - **task_id** (`int`): The task ID to query CL class map.
562
563        **Returns:**
564        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
565            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes of the task minus 1.
566            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers starting from the total number of classes of all previous tasks.
567        """
568        num_classes_t = len(self.class_split[task_id - 1])
569        if self.cl_paradigm == "TIL":
570            return {self.class_split[task_id - 1][i]: i for i in range(num_classes_t)}
571        if self.cl_paradigm == "CIL":
572            return {
573                self.class_split[task_id - 1][i]: i + (task_id - 1) * self.num_classes
574                for i in range(num_classes_t)
575            }
576
577    def mean(self, task_id: int) -> tuple[float]:
578        """The mean values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the mean values are the same as the original dataset.
579
580        **Returns:**
581        - The mean values for normalisation.
582        """
583        return self.mean_original
584
585    def std(self, task_id: int) -> tuple[float]:
586        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the standard deviation values are the same as the original dataset's.
587
588        **Returns:**
589        - The standard deviation values for normalisation.
590        """
591        return self.std_original
592
593    def get_class_subset(self, dataset: Dataset, classes: list[int]) -> Dataset:
594        """A utility method to retrieve the subset of a PyTorch `Dataset` containing only the given classes. It could be useful when constructing the split CL dataset.
595
596        **Args:**
597        - **dataset** (`Dataset`): the original dataset to retrieve subset from.
598        - **classes** (`list[int]`): the classes of the subset.
599
600        **Returns:**
601        - `Dataset`: the subset of the original dataset restricted to the given classes.
602        """
603        # get the indices of the dataset that belong to the classes
604        idx = np.isin(dataset.targets, classes)
605
606        # subset the dataset by the indices, in-place operation
607        dataset.data = dataset.data[idx]
608        dataset.targets = dataset.targets[idx]
609
610        return dataset

The base class of split continual learning datasets, which are constructed by splitting the classes of an original dataset into tasks, inherited from CLDataset.

CLSplitDataset( root: str, num_tasks: int, class_split: list[list[int]], validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None)
504    def __init__(
505        self,
506        root: str,
507        num_tasks: int,
508        class_split: list[list[int]],
509        validation_percentage: float,
510        batch_size: int = 1,
511        num_workers: int = 10,
512        custom_transforms: Callable | transforms.Compose | None = None,
513        custom_target_transforms: Callable | transforms.Compose | None = None,
514    ):
515        """Initialise the CL dataset object providing the root where data files live.
516
517        **Args:**
518        - **root** (`str`): the root directory where the original data files for constructing the CL dataset physically live.
519        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
520        - **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
521        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
522        - **batch_size** (`int`): The batch size in train, val, test dataloader.
523        - **num_workers** (`int`): the number of workers for dataloaders.
524        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
525        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
526        """
527
528        self.class_split = class_split
529        """Store the class split for each task. Used when constructing the split dataset."""
530
531        super().__init__(
532            root=root,
533            num_tasks=num_tasks,
534            validation_percentage=validation_percentage,
535            batch_size=batch_size,
536            num_workers=num_workers,
537            custom_transforms=custom_transforms,
538            custom_target_transforms=custom_target_transforms,
539        )

Initialise the CL dataset object providing the root where data files live.

Args:

  • root (str): the root directory where the original data files for constructing the CL dataset physically live.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset.
  • class_split (list[list[int]]): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
  • validation_percentage (float): the percentage to randomly split some of the training data into validation data.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalise, permute and so on are not included.
  • custom_target_transforms (transform or transforms.Compose or None): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
num_classes: int

The number of classes in the original dataset. It must be provided in subclasses.

mean_original: tuple[float]

The mean values for normalisation. It must be provided in subclasses.

std_original: tuple[float]

The standard deviation values for normalisation. It must be provided in subclasses.

class_split

Store the class split for each task. Used when constructing the split dataset.
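
A hedged sketch of constructing a split CL dataset from such a class split; it assumes `SplitCIFAR100` accepts the constructor arguments listed above. Each inner list of `class_split` is the set of original class labels assigned to one task.

    from clarena.cl_datasets import SplitCIFAR100

    cl_dataset = SplitCIFAR100(
        root="data/",
        num_tasks=5,
        class_split=[list(range(20 * t, 20 * (t + 1))) for t in range(5)],  # 5 tasks x 20 classes
        validation_percentage=0.1,
        batch_size=64,
    )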

def sanity_check(self) -> None:
541    def sanity_check(self) -> None:
542        """Check the sanity of the arguments.
543
544        **Raises:**
545        - **ValueError**: when the length of `class_split` is not equal to `num_tasks`.
546        - **ValueError**: when any of the lists in `class_split` has fewer than 2 elements. A classification task must have at least 2 classes.
547        """
548        if len(self.class_split) != self.num_tasks:
549            raise ValueError(
550                "The length of class split is not equal to number of tasks!"
551            )
552        if any(len(split) < 2 for split in self.class_split):
553            raise ValueError("Each class split must contain at least 2 elements!")
554
555        super().sanity_check()

Check the sanity of the arguments.

Raises:

  • ValueError: when the length of class_split is not equal to num_tasks.
  • ValueError: when any of the lists in class_split has fewer than 2 elements. A classification task must have at least 2 classes.
def cl_class_map(self, task_id: int) -> dict[str | int, int]:
557    def cl_class_map(self, task_id: int) -> dict[str | int, int]:
558        """The mapping of classes of task `task_id` to fit continual learning settings `self.cl_paradigm`.
559
560        **Args:**
561        - **task_id** (`int`): The task ID to query CL class map.
562
563        **Returns:**
564        - The CL class map of the task. Key is original class label, value is integer class label for continual learning.
565            - If `self.cl_paradigm` is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes of the task minus 1.
566            - If `self.cl_paradigm` is 'CIL', the mapped class labels of a task should be continuous integers starting from the total number of classes of all previous tasks.
567        """
568        num_classes_t = len(self.class_split[task_id - 1])
569        if self.cl_paradigm == "TIL":
570            return {self.class_split[task_id - 1][i]: i for i in range(num_classes_t)}
571        if self.cl_paradigm == "CIL":
572            return {
573                self.class_split[task_id - 1][i]: i + (task_id - 1) * self.num_classes
574                for i in range(num_classes_t)
575            }

The mapping of classes of task task_id to fit continual learning settings self.cl_paradigm.

Args:

  • task_id (int): The task ID to query CL class map.

Returns:

  • The CL class map of the task. Key is original class label, value is integer class label for continual learning.
    • If self.cl_paradigm is 'TIL', the mapped class labels of a task should be continuous integers from 0 to the number of classes of the task minus 1.
    • If self.cl_paradigm is 'CIL', the mapped class labels of a task should be continuous integers starting from the total number of classes of all previous tasks (see the sketch below).
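
An illustrative standalone sketch of the mapping above. It assumes equal-sized splits and that `num_classes` equals the per-task split size; both are assumptions for illustration, not guarantees of the CLArena implementation.

    # illustrative sketch, mirroring cl_class_map() above
    class_split = [[0, 1], [2, 3], [4, 5]]
    num_classes = 2  # assumed per-task class count for this sketch

    def cl_class_map(task_id: int, cl_paradigm: str) -> dict[int, int]:
        split_t = class_split[task_id - 1]
        if cl_paradigm == "TIL":
            return {c: i for i, c in enumerate(split_t)}
        if cl_paradigm == "CIL":
            return {c: i + (task_id - 1) * num_classes for i, c in enumerate(split_t)}

    print(cl_class_map(2, "TIL"))  # {2: 0, 3: 1}
    print(cl_class_map(2, "CIL"))  # {2: 2, 3: 3}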
def mean(self, task_id: int) -> tuple[float]:
577    def mean(self, task_id: int) -> tuple[float]:
578        """The mean values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the mean values are the same as the original dataset.
579
580        **Returns:**
581        - The mean values for normalisation.
582        """
583        return self.mean_original

The mean values for normalisation of task task_id. Used when constructing the dataset. In split CL dataset, the mean values are the same as the original dataset.

Returns:

  • The mean values for normalisation.
def std(self, task_id: int) -> tuple[float]:
585    def std(self, task_id: int) -> tuple[float]:
586        """The standard deviation values for normalisation of task `task_id`. Used when constructing the dataset. In split CL dataset, the standard deviation values are the same as the original dataset's.
587
588        **Returns:**
589        - The standard deviation values for normalisation.
590        """
591        return self.std_original

The standard deviation values for normalisation of task task_id. Used when constructing the dataset. In split CL dataset, the standard deviation values are the same as the original dataset's.

Returns:

  • The standard deviation values for normalisation.
def get_class_subset( self, dataset: torch.utils.data.dataset.Dataset, classes: list[int]) -> torch.utils.data.dataset.Dataset:
593    def get_class_subset(self, dataset: Dataset, classes: list[int]) -> Dataset:
594        """A utility method to retrieve the subset of a PyTorch `Dataset` containing only the given classes. It could be useful when constructing the split CL dataset.
595
596        **Args:**
597        - **dataset** (`Dataset`): the original dataset to retrieve subset from.
598        - **classes** (`list[int]`): the classes of the subset.
599
600        **Returns:**
601        - `Dataset`: the subset of the original dataset restricted to the given classes.
602        """
603        # get the indices of the dataset that belong to the classes
604        idx = np.isin(dataset.targets, classes)
605
606        # subset the dataset by the indices, in-place operation
607        dataset.data = dataset.data[idx]
608        dataset.targets = dataset.targets[idx]
609
610        return dataset

A utility method to retrieve the subset of a PyTorch Dataset containing only the given classes. It could be useful when constructing the split CL dataset.

Args:

  • dataset (Dataset): the original dataset to retrieve subset from.
  • classes (list[int]): the classes of the subset.

Returns:

  • Dataset: the subset of the original dataset restricted to the given classes.
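
A hedged sketch of calling `get_class_subset()` from a hypothetical subclass `setup()`. Torchvision's MNIST is used because its `.data` and `.targets` are both tensors, so the boolean-mask indexing above applies directly; the subclass name, the `dataset_train_t` attribute, and the assumption that `self.root` and `self.task_id` are stored by the base class are for illustration only.

    from torchvision.datasets import MNIST
    from clarena.cl_datasets import CLSplitDataset

    class MySplitMNIST(CLSplitDataset):  # hypothetical subclass for illustration
        def setup(self, stage: str) -> None:
            if stage == "fit":
                full_train = MNIST(root=self.root, train=True, download=True)
                # keep only the classes assigned to the current task
                self.dataset_train_t = self.get_class_subset(
                    full_train, classes=self.class_split[self.task_id - 1]
                )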
class CLClassMapping:
613class CLClassMapping:
614    """CL Class mapping to dataset labels. Used as a PyTorch target Transform."""
615
616    def __init__(self, cl_class_map: dict[str | int, int]):
617        """Initialise the CL class mapping transform object from the CL class map of a task.
618
619        **Args:**
620        - **cl_class_map** (`dict[str | int, int]`): the CL class map for a task.
621        """
622        self.cl_class_map = cl_class_map
623
624    def __call__(self, target: torch.Tensor) -> torch.Tensor:
625        """The CL class mapping transform to dataset labels. It is defined as a callable object like a PyTorch Transform.
626
627        **Args:**
628        - **target** (`Tensor`): the target tensor.
629
630        **Returns:**
631        - The transformed target tensor.
632        """
633
634        return self.cl_class_map[target]

CL Class mapping to dataset labels. Used as a PyTorch target Transform.

CLClassMapping(cl_class_map: dict[str | int, int])
616    def __init__(self, cl_class_map: dict[str | int, int]):
617        """Initialise the CL class mapping transform object from the CL class map of a task.
618
619        **Args:**
620        - **cl_class_map** (`dict[str | int, int]`): the CL class map for a task.
621        """
622        self.cl_class_map = cl_class_map

Initialise the CL class mapping transform object from the CL class map of a task.

Args:

  • cl_class_map (dict[str | int, int]): the CL class map for a task.
cl_class_map
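
A minimal usage sketch of `CLClassMapping` as a torchvision `target_transform`, assuming integer labels as in MNIST/CIFAR; the class map shown is a toy example rather than one produced by a real experiment.

    from torchvision.datasets import MNIST
    from clarena.cl_datasets import CLClassMapping

    cl_class_map = {0: 0, 1: 1}  # e.g. returned by cl_class_map(task_id) of a CL dataset
    target_transform = CLClassMapping(cl_class_map)

    dataset = MNIST(
        root="data/",
        train=True,
        download=True,
        target_transform=target_transform,  # maps original labels to CL labels
    )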
class Permute:
637class Permute:
638    """Permutation operation to image. Used to construct permuted CL dataset.
639
640    Used as a PyTorch Dataset Transform.
641    """
642
643    def __init__(
644        self,
645        img_size: torch.Size,
646        mode: str = "first_channel_only",
647        seed: int | None = None,
648    ):
649        """Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.
650
651        **Args:**
652        - **img_size** (`torch.Size`): the size of the image to be permuted.
653        - **mode** (`str`): the mode of permutation, should be one of the following:
654            - 'all': permute all pixels.
655            - 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
656            - 'first_channel_only': permute only the first channel.
657        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
658        """
659        self.mode = mode
660        """Store the mode of permutation."""
661
662        # get generator for permutation
663        torch_generator = torch.Generator()
664        if seed is not None:
665            torch_generator.manual_seed(seed)
666
667        # calculate the number of pixels from the image size
668        if self.mode == "all":
669            num_pixels = img_size[0] * img_size[1] * img_size[2]
670        elif self.mode in ("by_channel", "first_channel_only"):
671            num_pixels = img_size[1] * img_size[2]
672
673        self.permute: torch.Tensor = torch.randperm(
674            num_pixels, generator=torch_generator
675        )
676        """The permutation order, a `Tensor` holding a random permutation of [0, 1, ..., `num_pixels` - 1] generated with the given seed. It is the core element of the permutation operation."""
677
678    def __call__(self, img: torch.Tensor) -> torch.Tensor:
679        """The permutation operation to image is defined as a callable object like a PyTorch Transform.
680
681        **Args:**
682        - **img** (`Tensor`): image to be permuted. Must match the size of `img_size` in the initialisation.
683
684        **Returns:**
685        - The permuted image (`Tensor`).
686        """
687
688        if self.mode == "all":
689
690            img_flat = img.view(
691                -1
692            )  # flatten the whole image to 1d so that it can be applied 1d permuted order
693            img_flat_permuted = img_flat[self.permute]  # conduct permutation operation
694            img_permuted = img_flat_permuted.view(
695                img.size()
696            )  # return to the original image shape
697            return img_permuted
698
699        if self.mode == "by_channel":
700
701            permuted_channels = []
702            for i in range(img.size(0)):
703                # act on every channel
704                channel_flat = img[i].view(
705                    -1
706                )  # flatten the channel to 1d so that it can be applied 1d permuted order
707                channel_flat_permuted = channel_flat[
708                    self.permute
709                ]  # conduct permutation operation
710                channel_permuted = channel_flat_permuted.view(
711                    img[0].size()
712                )  # return to the original channel shape
713                permuted_channels.append(channel_permuted)
714            img_permuted = torch.stack(
715                permuted_channels
716            )  # stack the permuted channels to restore the image
717            return img_permuted
718
719        if self.mode == "first_channel_only":
720
721            first_channel_flat = img[0].view(
722                -1
723            )  # flatten the first channel to 1d so that it can be applied 1d permuted order
724            first_channel_flat_permuted = first_channel_flat[
725                self.permute
726            ]  # conduct permutation operation
727            first_channel_permuted = first_channel_flat_permuted.view(
728                img[0].size()
729            )  # return to the original channel shape
730
731            img_permuted = img.clone()
732            img_permuted[0] = first_channel_permuted
733
734            return img_permuted

Permutation operation to image. Used to construct permuted CL dataset.

Used as a PyTorch Dataset Transform.

Permute( img_size: torch.Size, mode: str = 'first_channel_only', seed: int | None = None)
643    def __init__(
644        self,
645        img_size: torch.Size,
646        mode: str = "first_channel_only",
647        seed: int | None = None,
648    ):
649        """Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.
650
651        **Args:**
652        - **img_size** (`torch.Size`): the size of the image to be permuted.
653        - **mode** (`str`): the mode of permutation, should be one of the following:
654            - 'all': permute all pixels.
655            - 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
656            - 'first_channel_only': permute only the first channel.
657        - **seed** (`int` or `None`): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
658        """
659        self.mode = mode
660        """Store the mode of permutation."""
661
662        # get generator for permutation
663        torch_generator = torch.Generator()
664        if seed is not None:
665            torch_generator.manual_seed(seed)
666
667        # calculate the number of pixels from the image size
668        if self.mode == "all":
669            num_pixels = img_size[0] * img_size[1] * img_size[2]
670        elif self.mode in ("by_channel", "first_channel_only"):
671            num_pixels = img_size[1] * img_size[2]
672
673        self.permute: torch.Tensor = torch.randperm(
674            num_pixels, generator=torch_generator
675        )
676        """The permutation order, a `Tensor` holding a random permutation of [0, 1, ..., `num_pixels` - 1] generated with the given seed. It is the core element of the permutation operation."""

Initialise the Permute transform object. The permutation order is constructed in the initialisation to save runtime.

Args:

  • img_size (torch.Size): the size of the image to be permuted.
  • mode (str): the mode of permutation, should be one of the following:
    • 'all': permute all pixels.
    • 'by_channel': permute channel by channel separately. The same permutation order is applied to all channels.
    • 'first_channel_only': permute only the first channel.
  • seed (int or None): seed for permutation operation. If None, the permutation will use a default seed from PyTorch generator.
mode

Store the mode of permutation.

permute: torch.Tensor

The permutation order, a Tensor holding a random permutation of [0, 1, ..., num_pixels - 1] generated with the given seed. It is the core element of the permutation operation.
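
A minimal sketch of using `Permute` inside a transform pipeline, assuming a 1x28x28 image such as MNIST; the normalisation statistics are the standard MNIST values and are used here only for illustration.

    import torch
    from torchvision import transforms
    from clarena.cl_datasets import Permute

    permute = Permute(img_size=torch.Size([1, 28, 28]), mode="first_channel_only", seed=42)

    pipeline = transforms.Compose([
        transforms.ToTensor(),                       # PIL image -> float tensor in [0, 1]
        permute,                                     # fixed pixel permutation for this task
        transforms.Normalize((0.1307,), (0.3081,)),  # standard MNIST statistics (illustrative)
    ])

    img = torch.rand(1, 28, 28)
    img_permuted = permute(img)  # same shape; first-channel pixels shuffled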