clarena.mtl_datasets

Multi-Task Learning Datasets

This submodule provides the multi-task learning datasets that can be used in CLArena.

Here are the base classes for multi-task learning datasets, which inherit from Lightning `LightningDataModule`:

 - MTLDataset: The base class for all multi-task learning datasets.
 - MTLCombinedDataset: The base class for combined multi-task learning datasets. A child class of MTLDataset.
 - MTLDatasetFromCL: The base class for constructing multi-task learning datasets from continual learning datasets. A child class of MTLDataset.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement MTL datasets:

 - Configure MTL Dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/mtl-dataset
 - Implement Custom MTL Dataset: https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/mtl_dataset

r"""

# Multi-Task Learning Datasets

This submodule provides the **multi-task learning datasets** that can be used in CLArena.

Here are the base classes for multi-task learning datasets, which inherit from Lightning `LightningDataModule`:

- `MTLDataset`: The base class for all multi-task learning datasets.
    - `MTLCombinedDataset`: The base class for combined multi-task learning datasets. A child class of `MTLDataset`.
    - `MTLDatasetFromCL`: The base class for constructing multi-task learning datasets from continual learning datasets. A child class of `MTLDataset`.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement MTL datasets:

- [**Configure MTL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/mtl-dataset)
- [**Implement Custom MTL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/mtl_dataset)
"""

from .base import MTLDataset, MTLCombinedDataset, MTLDatasetFromCL
from .combined import Combined


__all__ = ["MTLDataset", "MTLCombinedDataset", "MTLDatasetFromCL", "Combined"]
class MTLDataset(LightningDataModule):
    r"""The base class of multi-task learning datasets."""

    def __init__(
        self,
        root: str | dict[int, str],
        num_tasks: int,
        sampling_strategy: str = "mixed",
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directories for each task.
        - **num_tasks** (`int`): the maximum number of tasks supported by the MTL dataset.
        - **sampling_strategy** (`str`): the sampling strategy that constructs training batches from the tasks' datasets; one of:
            - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
        - **batch_size** (`int`): the batch size for the train, validation and test dataloaders.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization and so on are not included.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, the same number of channels is repeated for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize.
        If it is a dict, the keys are task IDs and the values are the sizes to resize to for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        """
        super().__init__()

        self.root: dict[int, str] = (
            OmegaConf.to_container(root)
            if isinstance(root, DictConfig)
            else root
            if isinstance(root, dict)
            else {t: root for t in range(1, num_tasks + 1)}
        )
        r"""The dict of root directories of the original data files for each task."""
        self.num_tasks: int = num_tasks
        r"""The maximum number of tasks supported by the dataset."""
        self.sampling_strategy: str = sampling_strategy
        r"""The sampling strategy for constructing training batches from the tasks' datasets."""
        self.batch_size: int = batch_size
        r"""The batch size for dataloaders."""
        self.num_workers: int = num_workers
        r"""The number of workers for dataloaders."""

        self.custom_transforms: dict[int, Callable | transforms.Compose | None] = (
            OmegaConf.to_container(custom_transforms)
            if isinstance(custom_transforms, DictConfig)
            else custom_transforms
            if isinstance(custom_transforms, dict)
            else {t: custom_transforms for t in range(1, num_tasks + 1)}
        )
        r"""The dict of custom transforms for each task."""
        self.repeat_channels: dict[int, int | None] = (
            OmegaConf.to_container(repeat_channels)
            if isinstance(repeat_channels, DictConfig)
            else repeat_channels
            if isinstance(repeat_channels, dict)
            else {t: repeat_channels for t in range(1, num_tasks + 1)}
        )
        r"""The dict of number of channels to repeat for each task."""
        self.to_tensor: dict[int, bool] = (
            OmegaConf.to_container(to_tensor)
            if isinstance(to_tensor, DictConfig)
            else to_tensor
            if isinstance(to_tensor, dict)
            else {t: to_tensor for t in range(1, num_tasks + 1)}
        )
        r"""The dict of to_tensor flags for each task."""
        self.resize: dict[int, tuple[int, int] | None] = (
            {t: tuple(rs) if rs else None for t, rs in resize.items()}
            if isinstance(resize, (dict, DictConfig))
            else {
                t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1)
            }
        )
        r"""The dict of sizes to resize to for each task."""

        # dataset containers
        self.dataset_train: dict[int, Any] = {}
        r"""The dictionary storing the training dataset object of each task. Keys are task IDs and values are the dataset objects, which can be PyTorch Dataset objects or any other dataset objects.

        Note that they must be task-labelled, i.e. the elements of the dataset objects must be tuples of (input, target, task_id). Use the `TaskLabelledDataset` wrapper if necessary."""
        self.dataset_val: dict[int, Any] = {}
        r"""The dictionary storing the validation dataset object of each task. Keys are task IDs and values are the dataset objects, which can be PyTorch Dataset objects or any other dataset objects.

        Note that they must be task-labelled, i.e. the elements of the dataset objects must be tuples of (input, target, task_id). Use the `TaskLabelledDataset` wrapper if necessary."""
        self.dataset_test: dict[int, Any] = {}
        r"""The dictionary storing the test dataset object of each task. Keys are task IDs and values are the dataset objects, which can be PyTorch Dataset objects or any other dataset objects.

        Note that they must be task-labelled, i.e. the elements of the dataset objects must be tuples of (input, target, task_id). Use the `TaskLabelledDataset` wrapper if necessary."""

        self.mean: dict[int, float] = {}
        r"""The dict of mean values for normalization for each task. Used when constructing the transforms."""
        self.std: dict[int, float] = {}
        r"""The dict of standard deviation values for normalization for each task. Used when constructing the transforms."""

        # task ID controls
        self.train_tasks: list[int]
        r"""The list of task IDs to be trained."""
        self.eval_tasks: list[int]
        r"""The list of task IDs to be evaluated."""

        MTLDataset.sanity_check(self)

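The per-task options above (`root`, `custom_transforms`, `repeat_channels`, `to_tensor`, `resize`) all follow the same broadcast rule: a single value applies to every task, while a dict keyed by task ID sets the value per task. A minimal standalone sketch of that rule (`broadcast_per_task` is a hypothetical helper in plain Python; the actual `__init__` additionally converts OmegaConf `DictConfig` inputs):

```python
def broadcast_per_task(value, num_tasks):
    """Expand a single value or a task-ID-keyed dict into a full per-task dict."""
    if isinstance(value, dict):
        return dict(value)  # already per-task; keep as given
    return {t: value for t in range(1, num_tasks + 1)}  # broadcast to all tasks

# A single value is shared by all tasks; a dict stays per-task.
shared = broadcast_per_task("data/", num_tasks=3)
per_task = broadcast_per_task({1: "data/a", 2: "data/b", 3: "data/c"}, num_tasks=3)
```

After broadcasting, `sanity_check` can simply require every such dict to have keys 1 to `num_tasks`.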
    def sanity_check(self) -> None:
        r"""Check that every per-task option dict covers exactly tasks 1 to `num_tasks`."""
        for attr in [
            "root",
            "custom_transforms",
            "repeat_channels",
            "to_tensor",
            "resize",
        ]:
            value = getattr(self, attr)
            expected_keys = set(range(1, self.num_tasks + 1))
            if set(value.keys()) != expected_keys:
                raise ValueError(
                    f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
                )

    @abstractmethod
    def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit multi-task learning. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to query the class map for.

        **Returns:**
        - **class_map** (`dict[str | int, int]`): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be consecutive integers from 0 to the number of classes minus 1.
        """

    @abstractmethod
    def prepare_data(self) -> None:
        r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`."""

    def setup(self, stage: str) -> None:
        r"""Set up the dataset for different stages.

        **Args:**
        - **stage** (`str`): the stage of the experiment; one of:
            - 'fit': training and validation datasets are assigned to `self.dataset_train` and `self.dataset_val`.
            - 'test': test datasets are assigned to `self.dataset_test`.
        """
        if stage == "fit":
            # train and validation must be constructed together because a sanity check on validation is conducted before training
            pylogger.debug("Construct train and validation dataset ...")

            for task_id in self.train_tasks:
                self.dataset_train[task_id], self.dataset_val[task_id] = (
                    self.train_and_val_dataset(task_id)
                )

                pylogger.info(
                    "Train and validation datasets for task %d are ready.", task_id
                )
                pylogger.info(
                    "Train dataset for task %d size: %d",
                    task_id,
                    len(self.dataset_train[task_id]),
                )
                pylogger.info(
                    "Validation dataset for task %d size: %d",
                    task_id,
                    len(self.dataset_val[task_id]),
                )

        elif stage == "test":
            pylogger.debug("Construct test dataset ...")

            for task_id in self.eval_tasks:
                self.dataset_test[task_id] = self.test_dataset(task_id)

                pylogger.info("Test dataset for task %d is ready.", task_id)
                pylogger.info(
                    "Test dataset for task %d size: %d",
                    task_id,
                    len(self.dataset_test[task_id]),
                )

    def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
        r"""Set up tasks for the multi-task learning experiment.

        **Args:**
        - **train_tasks** (`list[int]`): the list of task IDs to be trained. Used when constructing the train/val dataloaders.
        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. Used when constructing the test dataloader.
        """
        self.train_tasks = train_tasks
        self.eval_tasks = eval_tasks

    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
        r"""Set up evaluation tasks for the multi-task learning evaluation.

        **Args:**
        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.
        """
        self.eval_tasks = eval_tasks

    def train_and_val_transforms(self, task_id: int) -> transforms.Compose:
        r"""Transforms for the train and validation datasets of task `task_id`, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.

        **Args:**
        - **task_id** (`int`): the task ID of the training and validation dataset to get the transforms for.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels[task_id])
            if self.repeat_channels[task_id] is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor[task_id] else None
        resize_transform = (
            transforms.Resize(self.resize[task_id])
            if self.resize[task_id] is not None
            else None
        )
        normalization_transform = transforms.Normalize(
            self.mean[task_id], self.std[task_id]
        )

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.custom_transforms[task_id],
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

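The method above builds each optional stage (channel repeat, `ToTensor()`, resize, custom, normalization) as either a transform or `None`, then drops the `None`s with `filter` before composing. The same pattern in plain Python, with simple callables standing in for the torchvision transforms (`compose`, `scale`, and `normalize` here are hypothetical stand-ins, not the real API):

```python
def compose(stages):
    """Chain the non-None callables in order, like transforms.Compose over filter(None, ...)."""
    active = [s for s in stages if s is not None]

    def pipeline(x):
        for stage in active:
            x = stage(x)
        return x

    return pipeline

# Optional stages: disabled ones are simply None and get filtered out.
resize = None                       # e.g. no resize configured for this task
scale = lambda x: x / 255.0         # stand-in for ToTensor()'s [0, 1] scaling
normalize = lambda x: (x - 0.5) / 0.5

pipeline = compose([resize, scale, normalize])
```

Because `None` entries are filtered before composition, the remaining stages still run in the declared order, which is why the order of the list matters.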
    def test_transforms(self, task_id: int) -> transforms.Compose:
        r"""Transforms for the test dataset of task `task_id`. Only basic transforms like normalization and `ToTensor()` are included. It can be used in subclasses when constructing the dataset.

        **Args:**
        - **task_id** (`int`): the task ID of the test dataset to get the transforms for.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels[task_id])
            if self.repeat_channels[task_id] is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor[task_id] else None
        resize_transform = (
            transforms.Resize(self.resize[task_id])
            if self.resize[task_id] is not None
            else None
        )
        normalization_transform = transforms.Normalize(
            self.mean[task_id], self.std[task_id]
        )

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters; no custom transforms for test

    def target_transform(self, task_id: int) -> Callable:
        r"""Target transform for task `task_id`, which maps the original class labels to the integer class labels for multi-task learning. It can be used in subclasses when constructing the dataset.

        **Args:**
        - **task_id** (`int`): the task ID of the dataset to get the target transform for.

        **Returns:**
        - **target_transform** (`Callable`): the target transform function.
        """
        class_map = self.get_mtl_class_map(task_id)

        return ClassMapping(class_map=class_map)

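`get_mtl_class_map` is expected to return a map whose values are consecutive integers starting at 0; the target transform then just looks each original label up. A minimal stand-in illustrating that contract (`ClassMappingSketch` is hypothetical; the real `ClassMapping` wrapper comes from elsewhere in CLArena):

```python
class ClassMappingSketch:
    """Hypothetical stand-in: maps original class labels to contiguous MTL labels."""

    def __init__(self, class_map):
        # the mapped labels must be 0..num_classes-1 with no gaps
        assert sorted(class_map.values()) == list(range(len(class_map)))
        self.class_map = class_map

    def __call__(self, target):
        return self.class_map[target]

# e.g. a task whose original labels are strings
to_mtl = ClassMappingSketch({"cat": 0, "dog": 1, "bird": 2})
```

Passing such a callable as a dataset's `target_transform` remaps every label on the fly, so heads and losses only ever see the contiguous integer labels.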
    @abstractmethod
    def train_and_val_dataset(self, task_id: int) -> tuple[Any, Any]:
        r"""Get the training and validation dataset of task `task_id`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to get the training and validation dataset for.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation dataset of task `task_id`.
        """

    @abstractmethod
    def test_dataset(self, task_id: int) -> Any:
        r"""Get the test dataset of task `task_id`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to get the test dataset for.

        **Returns:**
        - **test_dataset** (`Any`): the test dataset of task `task_id`.
        """

    def train_dataloader(self) -> DataLoader:
        r"""DataLoader generator for the train stage. It is automatically called before training.

        **Returns:**
        - **train_dataloader** (`DataLoader`): the train DataLoader.
        """
        pylogger.debug(
            "Construct train dataloader ... sampling strategy: %s",
            self.sampling_strategy,
        )

        if self.sampling_strategy == "mixed":
            # mixed sampling strategy, which samples from all tasks' datasets
            concatenated_dataset = ConcatDataset(
                [self.dataset_train[task_id] for task_id in self.train_tasks]
            )

            return DataLoader(
                dataset=concatenated_dataset,
                batch_size=self.batch_size,
                shuffle=True,  # shuffle train batches to prevent overfitting
                num_workers=self.num_workers,
                drop_last=True,  # to avoid batchnorm error when the last batch has size 1
            )

        raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")

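Under the `'mixed'` strategy, the per-task training sets are concatenated into one dataset and shuffled, so a batch can mix samples from every task (which is why each sample must carry its task ID). A pure-Python sketch of the concatenation step (`ConcatSketch` is a hypothetical stand-in for `torch.utils.data.ConcatDataset`):

```python
class ConcatSketch:
    """Stand-in for ConcatDataset: chains per-task datasets end to end."""

    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, i):
        # walk the datasets until the global index falls inside one of them
        for d in self.datasets:
            if i < len(d):
                return d[i]
            i -= len(d)
        raise IndexError(i)

# task-labelled samples: (input, target, task_id)
task1 = [(0.1, 0, 1), (0.2, 1, 1)]
task2 = [(0.3, 0, 2)]
mixed = ConcatSketch([task1, task2])
```

The DataLoader's `shuffle=True` then draws indices uniformly over the combined length, interleaving tasks within each batch.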
    def val_dataloader(self) -> dict[int, DataLoader]:
        r"""DataLoader generator for the validation stage. It is automatically called before validation.

        **Returns:**
        - **val_dataloader** (`dict[int, DataLoader]`): the validation DataLoaders, keyed by task ID.
        """
        pylogger.debug("Construct validation dataloader...")

        return {
            task_id: DataLoader(
                dataset=dataset_val_t,
                batch_size=self.batch_size,
                shuffle=False,  # no need to shuffle val or test batches
                num_workers=self.num_workers,
            )
            for task_id, dataset_val_t in self.dataset_val.items()
        }

    def test_dataloader(self) -> dict[int, DataLoader]:
        r"""DataLoader generator for the test stage. It is automatically called before testing.

        **Returns:**
        - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoaders, keyed by task ID.
        """
        pylogger.debug("Construct test dataloader...")

        return {
            task_id: DataLoader(
                dataset=dataset_test_t,
                batch_size=self.batch_size,
                shuffle=False,  # no need to shuffle val or test batches
                num_workers=self.num_workers,
            )
            for task_id, dataset_test_t in self.dataset_test.items()
        }

The base class of multi-task learning datasets.

MTLDataset( root: str | dict[int, str], num_tasks: int, sampling_strategy: str = 'mixed', batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)
 35    def __init__(
 36        self,
 37        root: str | dict[int, str],
 38        num_tasks: int,
 39        sampling_strategy: str = "mixed",
 40        batch_size: int = 1,
 41        num_workers: int = 0,
 42        custom_transforms: (
 43            Callable
 44            | transforms.Compose
 45            | None
 46            | dict[int, Callable | transforms.Compose | None]
 47        ) = None,
 48        repeat_channels: int | None | dict[int, int | None] = None,
 49        to_tensor: bool | dict[int, bool] = True,
 50        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
 51    ) -> None:
 52        r"""
 53        **Args:**
 54        - **root** (`str` | `list[str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `list[str]`, it should be a list of strings, each string is the root directory for each task.
 55        - **num_tasks** (`int`): the maximum number of tasks supported by the MTL dataset.
 56        - **sampling_strategy** (`str`): the sampling strategy that construct training batch from each task's dataset; one of:
 57            - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
 58        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 59        - **num_workers** (`int`): the number of workers for dataloaders.
 60        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization and so on are not included.
 61        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
 62        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
 63        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
 64        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
 65        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
 66        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
 67        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
 68        """
 69        super().__init__()
 70
 71        self.root: dict[int, str] = (
 72            OmegaConf.to_container(root)
 73            if isinstance(root, DictConfig)
 74            else {t: root for t in range(1, num_tasks + 1)}
 75        )
 76        r"""The dict of root directories of the original data files for each task."""
 77        self.num_tasks: int = num_tasks
 78        r"""The maximum number of tasks supported by the dataset."""
 79        self.sampling_strategy: str = sampling_strategy
 80        r"""The sampling strategy for constructing training batch from each task's dataset."""
 81        self.batch_size: int = batch_size
 82        r"""The batch size for dataloaders."""
 83        self.num_workers: int = num_workers
 84        r"""The number of workers for dataloaders."""
 85
 86        self.custom_transforms: dict[int, Callable | transforms.Compose | None] = (
 87            OmegaConf.to_container(custom_transforms)
 88            if isinstance(custom_transforms, dict)
 89            else {t: custom_transforms for t in range(1, num_tasks + 1)}
 90        )
 91        r"""The dict of custom transforms for each task."""
 92        self.repeat_channels: dict[int, int | None] = (
 93            OmegaConf.to_container(repeat_channels)
 94            if isinstance(repeat_channels, dict)
 95            else {t: repeat_channels for t in range(1, num_tasks + 1)}
 96        )
 97        r"""The dict of number of channels to repeat for each task."""
 98        self.to_tensor: dict[int, bool] = (
 99            OmegaConf.to_container(to_tensor)
100            if isinstance(to_tensor, dict)
101            else {t: to_tensor for t in range(1, num_tasks + 1)}
102        )
103        r"""The dict of to_tensor flag for each task. """
104        self.resize: dict[int, tuple[int, int] | None] = (
105            {t: tuple(rs) if rs else None for t, rs in resize.items()}
106            if isinstance(resize, DictConfig)
107            else {
108                t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1)
109            }
110        )
111        r"""The dict of sizes to resize to for each task."""
112
113        # dataset containers
114        self.dataset_train: dict[int, Any] = {}
115        r"""The dictionary to store training dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects. 
116        
117        Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use `TaskLabelledDataset` wrapper if necessary."""
118        self.dataset_val: dict[int, Any] = {}
119        r"""The dictionary to store validation dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.
120        
121        Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use `TaskLabelledDataset` wrapper if necessary."""
122        self.dataset_test: dict[int, Any] = {}
123        r"""The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.
124        
125        Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use `TaskLabelledDataset` wrapper if necessary."""
126
127        self.mean: dict[int, float] = {}
128        r"""Tthe list of mean values for normalization for all tasks. Used when constructing the transforms."""
129        self.std: dict[int, float] = {}
130        r"""The list of standard deviation values for normalization for all tasks. Used when constructing the transforms."""
131
132        # task ID controls
133        self.train_tasks: list[int]
134        r""""The list of task IDs to be trained. It should be a list of integers, each integer is the task ID."""
135        self.eval_tasks: list[int]
136        r"""The list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID."""
137
138        MTLDataset.sanity_check(self)

Args:

  • root (str | list[str]): the root directory where the original data files for constructing the MTL dataset physically live. If list[str], it should be a list of strings, each string is the root directory for each task.
  • num_tasks (int): the maximum number of tasks supported by the MTL dataset.
  • sampling_strategy (str): the sampling strategy that construct training batch from each task's dataset; one of:
    • 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
root: dict[int, str]

The dict of root directories of the original data files for each task.

num_tasks: int

The maximum number of tasks supported by the dataset.

sampling_strategy: str

The sampling strategy for constructing training batch from each task's dataset.

batch_size: int

The batch size for dataloaders.

num_workers: int

The number of workers for dataloaders.

custom_transforms: dict[int, typing.Union[typing.Callable, torchvision.transforms.transforms.Compose, NoneType]]

The dict of custom transforms for each task.

repeat_channels: dict[int, int | None]

The dict of number of channels to repeat for each task.

to_tensor: dict[int, bool]

The dict of to_tensor flag for each task.

resize: dict[int, tuple[int, int] | None]

The dict of sizes to resize to for each task.

dataset_train: dict[int, typing.Any]

The dictionary to store training dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use TaskLabelledDataset wrapper if necessary.

dataset_val: dict[int, typing.Any]

The dictionary to store validation dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use TaskLabelledDataset wrapper if necessary.

dataset_test: dict[int, typing.Any]

The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use TaskLabelledDataset wrapper if necessary.

mean: dict[int, float]

The dict of mean values for normalization for each task. Used when constructing the transforms.

std: dict[int, float]

The dict of standard deviation values for normalization for each task. Used when constructing the transforms.

train_tasks: list[int]

The list of task IDs to be trained. Each element is an integer task ID.

eval_tasks: list[int]

The list of task IDs to be evaluated. Each element is an integer task ID.

def sanity_check(self) -> None:
140    def sanity_check(self) -> None:
141        r"""Sanity check."""
142        for attr in [
143            "root",
144            "custom_transforms",
145            "repeat_channels",
146            "to_tensor",
147            "resize",
148        ]:
149            value = getattr(self, attr)
150            expected_keys = set(range(1, self.num_tasks + 1))
151            if set(value.keys()) != expected_keys:
152                raise ValueError(
153                    f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
154                )

Sanity check.
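The check requires every per-task dict (root, custom_transforms, repeat_channels, to_tensor, resize) to be keyed by consecutive task IDs 1 to num_tasks. A stand-alone sketch of the same validation (the function name is illustrative, not part of the API):

```python
def check_task_keys(per_task_dict: dict, num_tasks: int) -> None:
    """Raise if the dict is not keyed by consecutive task IDs 1..num_tasks."""
    expected_keys = set(range(1, num_tasks + 1))
    if set(per_task_dict.keys()) != expected_keys:
        raise ValueError(
            "dict keys must be consecutive integers from 1 to num_tasks."
        )

check_task_keys({1: "data/t1", 2: "data/t2"}, num_tasks=2)  # passes silently
```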

@abstractmethod
def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
156    @abstractmethod
157    def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
158        r"""Get the mapping of classes of task `task_id` to fit multi-task learning. It must be implemented by subclasses.
159
160        **Args:**
161        - **task_id** (`int`): The task ID to query class map.
162
163        **Returns:**
164        - **class_map** (`dict[str | int, int]`): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be consecutive integers starting from 0.
165        """

Get the mapping of classes of task task_id to fit multi-task learning. It must be implemented by subclasses.

Args:

  • task_id (int): The task ID to query class map.

Returns:

  • class_map (dict[str | int, int]): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be consecutive integers starting from 0.
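For instance, a subclass whose task has string class labels could build the required contiguous mapping like this (a hypothetical sketch; the label list is made up):

```python
def build_class_map(original_labels: list) -> dict:
    """Map each original class label to a contiguous integer label 0..C-1."""
    return {label: i for i, label in enumerate(original_labels)}

build_class_map(["cat", "dog", "bird"])
# {"cat": 0, "dog": 1, "bird": 2}
```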
@abstractmethod
def prepare_data(self) -> None:
167    @abstractmethod
168    def prepare_data(self) -> None:
169        r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`."""

Use this to download and prepare data. It must be implemented by subclasses, as required by LightningDataModule.

def setup(self, stage: str) -> None:
171    def setup(self, stage: str) -> None:
172        r"""Set up the dataset for different stages.
173
174        **Args:**
175        - **stage** (`str`): the stage of the experiment; one of:
176            - 'fit': training and validation dataset should be assigned to `self.dataset_train` and `self.dataset_val`.
177            - 'test': test dataset should be assigned to `self.dataset_test`.
178        """
179        if stage == "fit":
180            # these two stages must be done together because a sanity check for validation is conducted before training
181            pylogger.debug("Construct train and validation dataset ...")
182
183            for task_id in self.train_tasks:
184
185                self.dataset_train[task_id], self.dataset_val[task_id] = (
186                    self.train_and_val_dataset(task_id)
187                )
188
189                pylogger.info(
190                    "Train and validation datasets for task %d are ready.", task_id
191                )
192                pylogger.info(
193                    "Train dataset for task %d size: %d",
194                    task_id,
195                    len(self.dataset_train[task_id]),
196                )
197                pylogger.info(
198                    "Validation dataset for task %d size: %d",
199                    task_id,
200                    len(self.dataset_val[task_id]),
201                )
202
203        elif stage == "test":
204
205            pylogger.debug("Construct test dataset ...")
206
207            for task_id in self.eval_tasks:
208
209                self.dataset_test[task_id] = self.test_dataset(task_id)
210
211                pylogger.info("Test dataset for task %d is ready.", task_id)
212                pylogger.info(
213                    "Test dataset for task %d size: %d",
214                    task_id,
215                    len(self.dataset_test[task_id]),
216                )

Set up the dataset for different stages.

Args:

  • stage (str): the stage of the experiment; one of:
    • 'fit': training and validation dataset should be assigned to self.dataset_train and self.dataset_val.
    • 'test': test dataset should be assigned to self.dataset_test.
def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
218    def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
219        r"""Set up tasks for the multi-task learning experiment.
220
221        **Args:**
222        - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the train/val dataloader.
223        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the test dataloader.
224        """
225        self.train_tasks = train_tasks
226        self.eval_tasks = eval_tasks

Set up tasks for the multi-task learning experiment.

Args:

  • train_tasks (list[int]): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the train/val dataloader.
  • eval_tasks (list[int]): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the test dataloader.
def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
228    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
229        r"""Set up evaluation tasks for the multi-task learning evaluation.
230
231        **Args:**
232        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated."""
233        self.eval_tasks = eval_tasks

Set up evaluation tasks for the multi-task learning evaluation.

Args:

  • eval_tasks (list[int]): the list of task IDs to be evaluated.
def train_and_val_transforms(self, task_id: int) -> torchvision.transforms.transforms.Compose:
235    def train_and_val_transforms(self, task_id: int) -> transforms.Compose:
236        r"""Transforms for train and validation datasets of task `task_id`, combining the custom transforms with basic transforms such as `normalization` and `ToTensor()`. It can be used in subclasses when constructing the dataset.
237
238        **Args:**
239        - **task_id** (`int`): the task ID of training and validation dataset to get the transforms for.
240
241        **Returns:**
242        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
243        """
244        repeat_channels_transform = (
245            transforms.Grayscale(num_output_channels=self.repeat_channels[task_id])
246            if self.repeat_channels[task_id] is not None
247            else None
248        )
249        to_tensor_transform = transforms.ToTensor() if self.to_tensor[task_id] else None
250        resize_transform = (
251            transforms.Resize(self.resize[task_id])
252            if self.resize[task_id] is not None
253            else None
254        )
255        normalization_transform = transforms.Normalize(
256            self.mean[task_id], self.std[task_id]
257        )
258
259        return transforms.Compose(
260            list(
261                filter(
262                    None,
263                    [
264                        repeat_channels_transform,
265                        to_tensor_transform,
266                        resize_transform,
267                        self.custom_transforms[task_id],
268                        normalization_transform,
269                    ],
270                )
271            )
272        )  # the order of transforms matters

Transforms for train and validation datasets of task task_id, combining the custom transforms with basic transforms such as normalization and ToTensor(). It can be used in subclasses when constructing the dataset.

Args:

  • task_id (int): the task ID of training and validation dataset to get the transforms for.

Returns:

  • train_and_val_transforms (transforms.Compose): the composed train/val transforms.
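The method builds each optional transform only when its setting is enabled, drops the `None` entries with `filter`, and composes the survivors in order. The same pattern, sketched without torchvision (the lambdas stand in for the real transforms):

```python
def compose(steps):
    """Apply the given callables left to right, like transforms.Compose."""
    def composed(x):
        for step in steps:
            x = step(x)
        return x
    return composed

maybe_steps = [
    lambda x: x * 2,   # stands in for e.g. a resize or repeat-channels step
    None,              # a disabled optional transform is simply None
    lambda x: x + 1,   # stands in for normalization
]
pipeline = compose(list(filter(None, maybe_steps)))
pipeline(3)  # (3 * 2) + 1 == 7; the order of transforms matters
```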
def test_transforms(self, task_id: int) -> torchvision.transforms.transforms.Compose:
274    def test_transforms(self, task_id: int) -> transforms.Compose:
275        r"""Transforms for test dataset of task `task_id`. Only basic transforms like `normalization` and `ToTensor()` are included. It can be used in subclasses when constructing the dataset.
276
277        **Args:**
278        - **task_id** (`int`): the task ID of test dataset to get the transforms for.
279
280        **Returns:**
281        - **test_transforms** (`transforms.Compose`): the composed test transforms.
282        """
283
284        repeat_channels_transform = (
285            transforms.Grayscale(num_output_channels=self.repeat_channels[task_id])
286            if self.repeat_channels[task_id] is not None
287            else None
288        )
289        to_tensor_transform = transforms.ToTensor() if self.to_tensor[task_id] else None
290        resize_transform = (
291            transforms.Resize(self.resize[task_id])
292            if self.resize[task_id] is not None
293            else None
294        )
295        normalization_transform = transforms.Normalize(
296            self.mean[task_id], self.std[task_id]
297        )
298
299        return transforms.Compose(
300            list(
301                filter(
302                    None,
303                    [
304                        repeat_channels_transform,
305                        to_tensor_transform,
306                        resize_transform,
307                        normalization_transform,
308                    ],
309                )
310            )
311        )  # the order of transforms matters. No custom transforms for test

Transforms for test dataset of task task_id. Only basic transforms like normalization and ToTensor() are included. It can be used in subclasses when constructing the dataset.

Args:

  • task_id (int): the task ID of test dataset to get the transforms for.

Returns:

  • test_transforms (transforms.Compose): the composed test transforms.
def target_transform(self, task_id: int) -> Callable:
313    def target_transform(self, task_id: int) -> Callable:
314        r"""Target transform for task `task_id`, which maps the original class labels to the integer class labels for multi-task learning. It can be used in subclasses when constructing the dataset.
315
316        **Args:**
317        - **task_id** (`int`): the task ID of dataset to get the target transform for.
318
319        **Returns:**
320        - **target_transform** (`Callable`): the target transform function.
321        """
322
323        class_map = self.get_mtl_class_map(task_id)
324
325        target_transform = ClassMapping(class_map=class_map)
326        return target_transform

Target transform for task task_id, which maps the original class labels to the integer class labels for multi-task learning. It can be used in subclasses when constructing the dataset.

Args:

  • task_id (int): the task ID of dataset to get the target transform for.

Returns:

  • target_transform (Callable): the target transform function.
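`ClassMapping` is CLArena's internal wrapper; functionally it behaves like a callable that looks the original label up in the class map. A rough stand-in (the class below is illustrative, not the real implementation):

```python
class ClassMappingSketch:
    """Callable target transform: original label -> MTL integer label."""

    def __init__(self, class_map: dict) -> None:
        self.class_map = class_map

    def __call__(self, target):
        return self.class_map[target]

target_transform = ClassMappingSketch({"cat": 0, "dog": 1})
target_transform("dog")  # -> 1
```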
@abstractmethod
def train_and_val_dataset(self, task_id: int) -> tuple[typing.Any, typing.Any]:
328    @abstractmethod
329    def train_and_val_dataset(self, task_id: int) -> tuple[Any, Any]:
330        r"""Get the training and validation dataset of task `task_id`. It must be implemented by subclasses.
331
332        **Args:**
333        - **task_id** (`int`): the task ID to get the training and validation dataset for.
334
335        **Returns:**
336        - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation dataset of task `task_id`.
337        """

Get the training and validation dataset of task task_id. It must be implemented by subclasses.

Args:

  • task_id (int): the task ID to get the training and validation dataset for.

Returns:

  • train_and_val_dataset (tuple[Any, Any]): the train and validation dataset of task task_id.
@abstractmethod
def test_dataset(self, task_id: int) -> Any:
339    @abstractmethod
340    def test_dataset(self, task_id: int) -> Any:
341        """Get the test dataset of task `task_id`. It must be implemented by subclasses.
342
343        **Args:**
344        - **task_id** (`int`): the task ID to get the test dataset for.
345
346        **Returns:**
347        - **test_dataset** (`Any`): the test dataset of task `task_id`.
348        """

Get the test dataset of task task_id. It must be implemented by subclasses.

Args:

  • task_id (int): the task ID to get the test dataset for.

Returns:

  • test_dataset (Any): the test dataset of task task_id.
def train_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
350    def train_dataloader(self) -> DataLoader:
351        r"""DataLoader generator for the train stage. It is automatically called before training.
352
353        **Returns:**
354        - **train_dataloader** (`DataLoader`): the train DataLoader over all training tasks.
355        """
356
357        pylogger.debug(
358            "Construct train dataloader ... sampling_strategy method: %s",
359            self.sampling_strategy,
360        )
361
362        if self.sampling_strategy == "mixed":
363            # mixed sampling strategy, which samples from all tasks' datasets
364
365            concatenated_dataset = ConcatDataset(
366                [self.dataset_train[task_id] for task_id in self.train_tasks]
367            )
368
369            return DataLoader(
370                dataset=concatenated_dataset,
371                batch_size=self.batch_size,
372                shuffle=True,  # shuffle train batch to prevent overfitting
373                num_workers=self.num_workers,
374                drop_last=True,  # to avoid batchnorm error (when batch_size is 1)
375            )

DataLoader generator for the train stage. It is automatically called before training.

Returns:

  • train_dataloader (DataLoader): the train DataLoader over all training tasks.
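Under the 'mixed' strategy, the selected tasks' training sets are concatenated into one pool that the DataLoader then shuffles and batches. Stripped of PyTorch, the concatenation step amounts to the following (the toy task-labelled samples are made up):

```python
from itertools import chain

# toy task-labelled datasets: each sample is a (input, target, task_id) tuple
dataset_train = {
    1: [("x1", 0, 1), ("x2", 1, 1)],
    2: [("x3", 0, 2)],
}
train_tasks = [1, 2]

# the ConcatDataset step: one joint pool mixing samples across tasks
concatenated = list(chain.from_iterable(dataset_train[t] for t in train_tasks))
len(concatenated)  # 3 samples drawn from all training tasks
```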
def val_dataloader(self) -> dict[int, torch.utils.data.dataloader.DataLoader]:
377    def val_dataloader(self) -> dict[int, DataLoader]:
378        r"""DataLoader generator for the validation stage. It is automatically called before validation.
379
380        **Returns:**
381        - **val_dataloader** (`dict[int, DataLoader]`): the validation DataLoader.
382        """
383
384        pylogger.debug("Construct validation dataloader...")
385
386        return {
387            task_id: DataLoader(
388                dataset=dataset_val_t,
389                batch_size=self.batch_size,
390                shuffle=False,  # don't have to shuffle val or test batch
391                num_workers=self.num_workers,
392            )
393            for task_id, dataset_val_t in self.dataset_val.items()
394        }

DataLoader generator for the validation stage. It is automatically called before validation.

Returns:

  • val_dataloader (dict[int, DataLoader]): the validation DataLoader.
def test_dataloader(self) -> dict[int, torch.utils.data.dataloader.DataLoader]:
396    def test_dataloader(self) -> dict[int, DataLoader]:
397        r"""DataLoader generator for the test stage. It is automatically called before testing.
398
399        **Returns:**
400        - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader.
401        """
402
403        pylogger.debug("Construct test dataloader...")
404
405        return {
406            task_id: DataLoader(
407                dataset=dataset_test_t,
408                batch_size=self.batch_size,
409                shuffle=False,  # don't have to shuffle val or test batch
410                num_workers=self.num_workers,
411            )
412            for task_id, dataset_test_t in self.dataset_test.items()
413        }

DataLoader generator for the test stage. It is automatically called before testing.

Returns:

  • test_dataloader (dict[int, DataLoader]): the test DataLoader.
class MTLCombinedDataset(clarena.mtl_datasets.MTLDataset):
416class MTLCombinedDataset(MTLDataset):
417    r"""The base class of multi-task learning datasets constructed as combinations of several single-task datasets (one dataset per task)."""
418
419    def __init__(
420        self,
421        datasets: dict[int, str],
422        root: str | dict[int, str],
423        sampling_strategy: str = "mixed",
424        batch_size: int = 1,
425        num_workers: int = 0,
426        custom_transforms: (
427            Callable
428            | transforms.Compose
429            | None
430            | dict[int, Callable | transforms.Compose | None]
431        ) = None,
432        repeat_channels: int | None | dict[int, int | None] = None,
433        to_tensor: bool | dict[int, bool] = True,
434        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
435    ) -> None:
436        r"""
437        **Args:**
438        - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
439        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `dict[int, str]`, it should be a dict of task IDs and their corresponding root directories.
440        - **sampling_strategy** (`str`): the sampling strategy that constructs training batches from each task's dataset; one of:
441            - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
442        - **batch_size** (`int`): The batch size in train, val, test dataloader.
443        - **num_workers** (`int`): the number of workers for dataloaders.
444        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and so on are not included.
445        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
446        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
447        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
448        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
449        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
450        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
451        """
452        super().__init__(
453            root=root,
454            num_tasks=len(
455                datasets
456            ),  # num_tasks is not explicitly provided, but derived from the datasets length
457            sampling_strategy=sampling_strategy,
458            batch_size=batch_size,
459            num_workers=num_workers,
460            custom_transforms=custom_transforms,
461            repeat_channels=repeat_channels,
462            to_tensor=to_tensor,
463            resize=resize,
464        )
465
466        self.original_dataset_python_classes: dict[int, Dataset] = {
467            t: str_to_class(dataset_class_path)
468            for t, dataset_class_path in datasets.items()
469        }
470        r"""The dict of dataset classes for each task."""
471
472    def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
473        r"""Get the mapping of classes of task `task_id` to fit multi-task learning.
474
475        **Args:**
476        - **task_id** (`int`): the task ID to query the class map.
477
478        **Returns:**
479        - **class_map** (`dict[str | int, int]`): the class map of the task. Keys are the original class labels and values are the integer class labels for multi-task learning. For multi-task learning, the mapped class labels of a task should be consecutive integers starting from 0.
480        """
481        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]
482        original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
483            original_dataset_python_class_t
484        ]
485        num_classes_t = original_dataset_constants_t.NUM_CLASSES
486        class_map_t = original_dataset_constants_t.CLASS_MAP
487
488        return {class_map_t[i]: i for i in range(num_classes_t)}
489
490    def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
491        r"""Set up tasks for the multi-task learning experiment.
492
493        **Args:**
494        - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader.
495        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader.
496        """
497        super().setup_tasks_expr(train_tasks=train_tasks, eval_tasks=eval_tasks)
498
499        for task_id in train_tasks + eval_tasks:
500            original_dataset_python_class_t = self.original_dataset_python_classes[
501                task_id
502            ]
503            original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
504                original_dataset_python_class_t
505            ]
506            self.mean[task_id] = original_dataset_constants_t.MEAN
507            self.std[task_id] = original_dataset_constants_t.STD
508
509    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
510        r"""Set up evaluation tasks for the multi-task learning evaluation.
511
512        **Args:**
513        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.
514        """
515        super().setup_tasks_eval(eval_tasks=eval_tasks)

The base class of multi-task learning datasets constructed as combinations of several single-task datasets (one dataset per task).

MTLCombinedDataset( datasets: dict[int, str], root: str | dict[int, str], sampling_strategy: str = 'mixed', batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)
419    def __init__(
420        self,
421        datasets: dict[int, str],
422        root: str | dict[int, str],
423        sampling_strategy: str = "mixed",
424        batch_size: int = 1,
425        num_workers: int = 0,
426        custom_transforms: (
427            Callable
428            | transforms.Compose
429            | None
430            | dict[int, Callable | transforms.Compose | None]
431        ) = None,
432        repeat_channels: int | None | dict[int, int | None] = None,
433        to_tensor: bool | dict[int, bool] = True,
434        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
435    ) -> None:
436        r"""
437        **Args:**
438        - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
439        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `dict[int, str]`, it should be a dict of task IDs and their corresponding root directories.
440        - **sampling_strategy** (`str`): the sampling strategy that constructs training batches from each task's dataset; one of:
441            - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
442        - **batch_size** (`int`): The batch size in train, val, test dataloader.
443        - **num_workers** (`int`): the number of workers for dataloaders.
444        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and so on are not included.
445        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
446        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
447        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
448        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
449        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
450        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
451        """
452        super().__init__(
453            root=root,
454            num_tasks=len(
455                datasets
456            ),  # num_tasks is not explicitly provided, but derived from the datasets length
457            sampling_strategy=sampling_strategy,
458            batch_size=batch_size,
459            num_workers=num_workers,
460            custom_transforms=custom_transforms,
461            repeat_channels=repeat_channels,
462            to_tensor=to_tensor,
463            resize=resize,
464        )
465
466        self.original_dataset_python_classes: dict[int, Dataset] = {
467            t: str_to_class(dataset_class_path)
468            for t, dataset_class_path in datasets.items()
469        }
470        r"""The dict of dataset classes for each task."""

Args:

  • datasets (dict[int, str]): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
  • root (str | dict[int, str]): the root directory where the original data files for constructing the MTL dataset physically live. If dict[int, str], it should be a dict of task IDs and their corresponding root directories.
  • sampling_strategy (str): the sampling strategy that constructs training batches from each task's dataset; one of:
    • 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
original_dataset_python_classes: dict[int, torch.utils.data.dataset.Dataset]

The dict of dataset classes for each task.

def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
472    def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
473        r"""Get the mapping of classes of task `task_id` to fit multi-task learning.
474
475        **Args:**
476        - **task_id** (`int`): the task ID to query the class map.
477
478        **Returns:**
479        - **class_map** (`dict[str | int, int]`): the class map of the task. Keys are the original class labels and values are the integer class labels for multi-task learning. For multi-task learning, the mapped class labels of a task should be consecutive integers starting from 0.
480        """
481        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]
482        original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
483            original_dataset_python_class_t
484        ]
485        num_classes_t = original_dataset_constants_t.NUM_CLASSES
486        class_map_t = original_dataset_constants_t.CLASS_MAP
487
488        return {class_map_t[i]: i for i in range(num_classes_t)}

Get the mapping of classes of task task_id to fit multi-task learning.

Args:

  • task_id (int): the task ID to query the class map.

Returns:

  • class_map (dict[str | int, int]): the class map of the task. Keys are the original class labels and values are the integer class labels for multi-task learning. For multi-task learning, the mapped class labels of a task should be consecutive integers starting from 0.
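Concretely, the method inverts the dataset's `CLASS_MAP` constant (integer index to original label) into the original-label-to-index direction that multi-task learning needs. With made-up constants in that shape:

```python
# hypothetical constants, in the shape DATASET_CONSTANTS_MAPPING provides
NUM_CLASSES = 3
CLASS_MAP = {0: "plane", 1: "car", 2: "ship"}  # index -> original label

# invert: original label -> consecutive MTL integer label
mtl_class_map = {CLASS_MAP[i]: i for i in range(NUM_CLASSES)}
# {"plane": 0, "car": 1, "ship": 2}
```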
def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
490    def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
491        r"""Set up tasks for the multi-task learning experiment.
492
493        **Args:**
494        - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader.
495        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader.
496        """
497        super().setup_tasks_expr(train_tasks=train_tasks, eval_tasks=eval_tasks)
498
499        for task_id in train_tasks + eval_tasks:
500            original_dataset_python_class_t = self.original_dataset_python_classes[
501                task_id
502            ]
503            original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[
504                original_dataset_python_class_t
505            ]
506            self.mean[task_id] = original_dataset_constants_t.MEAN
507            self.std[task_id] = original_dataset_constants_t.STD

Set up tasks for the multi-task learning experiment.

Args:

  • train_tasks (list[int]): the list of task IDs to be trained; used when constructing the dataloaders.
  • eval_tasks (list[int]): the list of task IDs to be evaluated; used when constructing the dataloaders.
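
As the source above shows, setup_tasks_expr also looks up each task's normalization statistics (MEAN and STD) from the original dataset's constants. A simplified sketch of that per-task lookup, where the constants table is an invented stand-in for DATASET_CONSTANTS_MAPPING (the MNIST/CIFAR-10 statistics shown are the commonly used values, but which datasets back which tasks here is purely illustrative):

```python
# Invented stand-in for DATASET_CONSTANTS_MAPPING:
# original dataset name -> its normalization constants.
DATASET_CONSTANTS = {
    "MNIST": {"MEAN": (0.1307,), "STD": (0.3081,)},
    "CIFAR10": {"MEAN": (0.4914, 0.4822, 0.4465), "STD": (0.2470, 0.2435, 0.2616)},
}

# Which original dataset backs each task (task IDs start from 1).
original_dataset_classes = {1: "MNIST", 2: "CIFAR10"}

# Populate per-task mean/std dicts, as setup_tasks_expr does.
mean, std = {}, {}
for task_id in [1, 2]:  # train_tasks + eval_tasks
    constants = DATASET_CONSTANTS[original_dataset_classes[task_id]]
    mean[task_id] = constants["MEAN"]
    std[task_id] = constants["STD"]

print(mean[1])  # (0.1307,)
```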
def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
509    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
510        r"""Set up evaluation tasks for the multi-task learning evaluation.
511
512        **Args:**
513        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.
514        """
515        super().setup_tasks_eval(eval_tasks=eval_tasks)

Set up evaluation tasks for the multi-task learning evaluation.

Args:

  • eval_tasks (list[int]): the list of task IDs to be evaluated.
class MTLDatasetFromCL(clarena.mtl_datasets.MTLDataset):
518class MTLDatasetFromCL(MTLDataset):
519    r"""Multi-task learning datasets constructed from the CL datasets.
520
521    This is usually for constructing the reference joint learning experiment for continual learning.
522    """
523
524    def __init__(
525        self,
526        cl_dataset: CLDataset,
527        sampling_strategy: str = "mixed",
528        batch_size: int = 1,
529        num_workers: int = 0,
530    ) -> None:
531        r"""Initialize the `MTLDatasetFromCL` object.
532
533        **Args:**
534        - **cl_dataset** (`CLDataset`): the CL dataset object to be used for constructing the MTL dataset.
535        - **sampling_strategy** (`str`): the sampling strategy that constructs training batches from each task's dataset; one of:
536            - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
537        - **batch_size** (`int`): The batch size in train, val, test dataloader.
538        - **num_workers** (`int`): the number of workers for dataloaders.
539        """
540
541        self.cl_dataset: CLDataset = cl_dataset
542        r"""The CL dataset for constructing the MTL dataset."""
543
544        super().__init__(
545            root=None,
546            num_tasks=cl_dataset.num_tasks,
547            sampling_strategy=sampling_strategy,
548            batch_size=batch_size,
549            num_workers=num_workers,
550            custom_transforms=None,  # already handled in the CL dataset
551            repeat_channels=None,
552            to_tensor=None,
553            resize=None,
554        )
555
556    def prepare_data(self) -> None:
557        r"""Download and prepare data."""
558        self.cl_dataset.prepare_data()  # prepare the CL dataset
559
560    def setup(self, stage: str) -> None:
561        r"""Set up the dataset for different stages.
562
563        **Args:**
564        - **stage** (`str`): the stage of the experiment; one of:
565            - 'fit': training and validation dataset should be assigned to `self.dataset_train` and `self.dataset_val`.
566            - 'test': test dataset should be assigned to `self.dataset_test`.
567        """
568        if stage == "fit":
569            pylogger.debug("Construct train and validation dataset ...")
570
571            # go through each task of continual learning to get the training dataset of each task
572            for task_id in range(1, self.num_tasks + 1):
573                self.cl_dataset.setup_task_id(task_id)
574                self.cl_dataset.setup(stage)
575
576                # label the training dataset with the task ID
577                task_labelled_dataset_train_t = TaskLabelledDataset(
578                    self.cl_dataset.dataset_train_t, task_id
579                )
580                self.dataset_train[task_id] = task_labelled_dataset_train_t
581
582                # label the validation dataset with the task ID
583                task_labelled_dataset_val_t = TaskLabelledDataset(
584                    self.cl_dataset.dataset_val_t, task_id
585                )
586                self.dataset_val[task_id] = task_labelled_dataset_val_t
587
588                pylogger.debug(
589                    "Train and validation dataset for task %d are ready.", task_id
590                )
591                pylogger.info(
592                    "Train dataset for task %d size: %d",
593                    task_id,
594                    len(self.dataset_train[task_id]),
595                )
596                pylogger.info(
597                    "Validation dataset for task %d size: %d",
598                    task_id,
599                    len(self.dataset_val[task_id]),
600                )
601
602        elif stage == "test":
603
604            pylogger.debug("Construct test dataset ...")
605
606            for task_id in self.eval_tasks:
607
608                self.cl_dataset.setup_task_id(task_id)
609                self.cl_dataset.setup(stage)
610
611                task_labelled_dataset_test_t = TaskLabelledDataset(
612                    self.cl_dataset.dataset_test[task_id], task_id
613                )
614
615                self.dataset_test[task_id] = task_labelled_dataset_test_t
616
617        pylogger.debug("Test dataset for task %d is ready.", task_id)
618                pylogger.info(
619                    "Test dataset for task %d size: %d",
620                    task_id,
621                    len(self.dataset_test[task_id]),
622                )
623
624    def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
625        r"""Get the mapping of classes of task `task_id` to fit multi-task learning.
626
627        **Args:**
628        - **task_id** (`int`): The task ID to query class map.
629
630        **Returns:**
631        - **class_map**(`dict[str | int, int]`): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be continuous integers from 0 to the number of classes.
632        """
633        return self.cl_dataset.get_cl_class_map(
634            task_id
635        )  # directly use the CL dataset's class map (from TIL setting)
636
637    def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
638        r"""Set up tasks for the multi-task learning experiment.
639
640        **Args:**
641        - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader.
642        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader.
643        """
644        super().setup_tasks_expr(train_tasks=train_tasks, eval_tasks=eval_tasks)
645
646        # MTL requires independent heads
647        self.cl_dataset.set_cl_paradigm(cl_paradigm="TIL")
648        for task_id in train_tasks + eval_tasks:
649            self.cl_dataset.setup_task_id(task_id)
650
651    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
652        r"""Set up evaluation tasks for the multi-task learning evaluation.
653
654        **Args:**
655        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated."""
656        super().setup_tasks_eval(eval_tasks=eval_tasks)
657
658        # MTL requires independent heads
659        self.cl_dataset.set_cl_paradigm(cl_paradigm="TIL")
660        for task_id in eval_tasks:
661            self.cl_dataset.setup_task_id(task_id)

Multi-task learning datasets constructed from the CL datasets.

This is usually for constructing the reference joint learning experiment for continual learning.

MTLDatasetFromCL( cl_dataset: clarena.cl_datasets.CLDataset, sampling_strategy: str = 'mixed', batch_size: int = 1, num_workers: int = 0)
524    def __init__(
525        self,
526        cl_dataset: CLDataset,
527        sampling_strategy: str = "mixed",
528        batch_size: int = 1,
529        num_workers: int = 0,
530    ) -> None:
531        r"""Initialize the `MTLDatasetFromCL` object.
532
533        **Args:**
534        - **cl_dataset** (`CLDataset`): the CL dataset object to be used for constructing the MTL dataset.
535        - **sampling_strategy** (`str`): the sampling strategy that constructs training batches from each task's dataset; one of:
536            - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
537        - **batch_size** (`int`): The batch size in train, val, test dataloader.
538        - **num_workers** (`int`): the number of workers for dataloaders.
539        """
540
541        self.cl_dataset: CLDataset = cl_dataset
542        r"""The CL dataset for constructing the MTL dataset."""
543
544        super().__init__(
545            root=None,
546            num_tasks=cl_dataset.num_tasks,
547            sampling_strategy=sampling_strategy,
548            batch_size=batch_size,
549            num_workers=num_workers,
550            custom_transforms=None,  # already handled in the CL dataset
551            repeat_channels=None,
552            to_tensor=None,
553            resize=None,
554        )

Initialize the MTLDatasetFromCL object.

Args:

  • cl_dataset (CLDataset): the CL dataset object to be used for constructing the MTL dataset.
  • sampling_strategy (str): the sampling strategy that constructs training batches from each task's dataset; one of:
    • 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
  • batch_size (int): the batch size for the train, val, and test dataloaders.
  • num_workers (int): the number of workers for the dataloaders.
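
The 'mixed' strategy above means a training batch may contain samples from any task. A rough sketch of the idea, using plain Python lists in place of real datasets (this is an illustration of the sampling concept, not the actual dataloader implementation):

```python
import random

# Toy per-task datasets of (sample, task_id) pairs. The real code wraps
# torch Datasets and attaches the task ID via a task-labelled wrapper.
datasets = {
    1: [("a0", 1), ("a1", 1), ("a2", 1)],
    2: [("b0", 2), ("b1", 2)],
}

# 'mixed' sampling: pool every task's samples and shuffle,
# so each batch may mix samples from different tasks.
pool = [item for task_data in datasets.values() for item in task_data]
random.Random(0).shuffle(pool)

batch_size = 2
batches = [pool[i : i + batch_size] for i in range(0, len(pool), batch_size)]
print(len(batches))  # 3 batches from 5 samples
```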

The CL dataset for constructing the MTL dataset.

def prepare_data(self) -> None:
556    def prepare_data(self) -> None:
557        r"""Download and prepare data."""
558        self.cl_dataset.prepare_data()  # prepare the CL dataset

Download and prepare data.

def setup(self, stage: str) -> None:
560    def setup(self, stage: str) -> None:
561        r"""Set up the dataset for different stages.
562
563        **Args:**
564        - **stage** (`str`): the stage of the experiment; one of:
565            - 'fit': training and validation dataset should be assigned to `self.dataset_train` and `self.dataset_val`.
566            - 'test': test dataset should be assigned to `self.dataset_test`.
567        """
568        if stage == "fit":
569            pylogger.debug("Construct train and validation dataset ...")
570
571            # go through each task of continual learning to get the training dataset of each task
572            for task_id in range(1, self.num_tasks + 1):
573                self.cl_dataset.setup_task_id(task_id)
574                self.cl_dataset.setup(stage)
575
576                # label the training dataset with the task ID
577                task_labelled_dataset_train_t = TaskLabelledDataset(
578                    self.cl_dataset.dataset_train_t, task_id
579                )
580                self.dataset_train[task_id] = task_labelled_dataset_train_t
581
582                # label the validation dataset with the task ID
583                task_labelled_dataset_val_t = TaskLabelledDataset(
584                    self.cl_dataset.dataset_val_t, task_id
585                )
586                self.dataset_val[task_id] = task_labelled_dataset_val_t
587
588                pylogger.debug(
589                    "Train and validation dataset for task %d are ready.", task_id
590                )
591                pylogger.info(
592                    "Train dataset for task %d size: %d",
593                    task_id,
594                    len(self.dataset_train[task_id]),
595                )
596                pylogger.info(
597                    "Validation dataset for task %d size: %d",
598                    task_id,
599                    len(self.dataset_val[task_id]),
600                )
601
602        elif stage == "test":
603
604            pylogger.debug("Construct test dataset ...")
605
606            for task_id in self.eval_tasks:
607
608                self.cl_dataset.setup_task_id(task_id)
609                self.cl_dataset.setup(stage)
610
611                task_labelled_dataset_test_t = TaskLabelledDataset(
612                    self.cl_dataset.dataset_test[task_id], task_id
613                )
614
615                self.dataset_test[task_id] = task_labelled_dataset_test_t
616
617        pylogger.debug("Test dataset for task %d is ready.", task_id)
618                pylogger.info(
619                    "Test dataset for task %d size: %d",
620                    task_id,
621                    len(self.dataset_test[task_id]),
622                )

Set up the dataset for different stages.

Args:

  • stage (str): the stage of the experiment; one of:
    • 'fit': training and validation dataset should be assigned to self.dataset_train and self.dataset_val.
    • 'test': test dataset should be assigned to self.dataset_test.
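
During setup, each per-task dataset from the CL dataset is wrapped in a TaskLabelledDataset so that every sample carries its task ID. The wrapper itself is defined elsewhere in CLArena; a minimal sketch of what such a wrapper might look like (this is an assumption about its shape, not the actual implementation):

```python
class TaskLabelledDatasetSketch:
    """Illustrative wrapper: yield each item together with its task ID."""

    def __init__(self, dataset, task_id):
        self.dataset = dataset  # any indexable dataset of (x, y) pairs
        self.task_id = task_id

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        x, y = self.dataset[index]
        return x, y, self.task_id  # attach the task ID to every sample


# Usage with a plain list standing in for a torch Dataset.
wrapped = TaskLabelledDatasetSketch([("img0", 0), ("img1", 1)], task_id=3)
print(wrapped[1])  # ('img1', 1, 3)
```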
def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
624    def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
625        r"""Get the mapping of classes of task `task_id` to fit multi-task learning.
626
627        **Args:**
628        - **task_id** (`int`): The task ID to query class map.
629
630        **Returns:**
631        - **class_map**(`dict[str | int, int]`): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be continuous integers from 0 to the number of classes.
632        """
633        return self.cl_dataset.get_cl_class_map(
634            task_id
635        )  # directly use the CL dataset's class map (from TIL setting)

Get the mapping of classes of task task_id to fit multi-task learning.

Args:

  • task_id (int): The task ID to query class map.

Returns:

  • class_map (dict[str | int, int]): the class map of the task. Keys are the original class labels and values are the integer class labels used for multi-task learning. The mapped class labels of each task should be consecutive integers from 0 to the number of classes minus one.
def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
637    def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
638        r"""Set up tasks for the multi-task learning experiment.
639
640        **Args:**
641        - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader.
642        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader.
643        """
644        super().setup_tasks_expr(train_tasks=train_tasks, eval_tasks=eval_tasks)
645
646        # MTL requires independent heads
647        self.cl_dataset.set_cl_paradigm(cl_paradigm="TIL")
648        for task_id in train_tasks + eval_tasks:
649            self.cl_dataset.setup_task_id(task_id)

Set up tasks for the multi-task learning experiment.

Args:

  • train_tasks (list[int]): the list of task IDs to be trained; used when constructing the dataloaders.
  • eval_tasks (list[int]): the list of task IDs to be evaluated; used when constructing the dataloaders.
def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
651    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
652        r"""Set up evaluation tasks for the multi-task learning evaluation.
653
654        **Args:**
655        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated."""
656        super().setup_tasks_eval(eval_tasks=eval_tasks)
657
658        # MTL requires independent heads
659        self.cl_dataset.set_cl_paradigm(cl_paradigm="TIL")
660        for task_id in eval_tasks:
661            self.cl_dataset.setup_task_id(task_id)

Set up evaluation tasks for the multi-task learning evaluation.

Args:

  • eval_tasks (list[int]): the list of task IDs to be evaluated.