clarena.mtl_datasets.combined

The submodule in mtl_datasets for the combined dataset used for multi-task learning.
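
Each task in the combined dataset is backed by exactly one of the datasets listed in `Combined.AVAILABLE_DATASETS` below, keyed by task ID (task IDs start at 1). A minimal sketch of the per-task mapping; the class paths here are illustrative, and the exact path format is resolved by `MTLCombinedDataset`:

datasets = {
    1: "torchvision.datasets.MNIST",         # task 1: MNIST
    2: "torchvision.datasets.CIFAR10",       # task 2: CIFAR-10
    3: "clarena.stl_datasets.raw.NotMNIST",  # task 3: NotMNIST
}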

r"""
The submodule in `mtl_datasets` for the combined dataset used for multi-task learning.
"""

__all__ = ["Combined"]

import logging
from typing import Callable

import torch
from tinyimagenet import TinyImageNet
from torch.utils.data import Dataset, random_split
from torchvision.datasets import (
    CIFAR10,
    CIFAR100,
    DTD,
    FER2013,
    GTSRB,
    KMNIST,
    MNIST,
    PCAM,
    SEMEION,
    SUN397,
    SVHN,
    USPS,
    Caltech101,
    Caltech256,
    CelebA,
    Country211,
    EuroSAT,
    FashionMNIST,
    Flowers102,
    Food101,
    RenderedSST2,
    StanfordCars,
)
from torchvision.datasets.vision import VisionDataset
from torchvision.transforms import transforms

from clarena.mtl_datasets import MTLCombinedDataset
from clarena.stl_datasets.raw import (
    CUB2002011,
    ArabicHandwrittenDigits,
    EMNISTBalanced,
    EMNISTByClass,
    EMNISTByMerge,
    EMNISTDigits,
    EMNISTLetters,
    FaceScrub10,
    FaceScrub20,
    FaceScrub50,
    FaceScrub100,
    FaceScrubFromHAT,
    FGVCAircraftFamily,
    FGVCAircraftManufacturer,
    FGVCAircraftVariant,
    KannadaMNIST,
    Linnaeus5_32,
    Linnaeus5_64,
    Linnaeus5_128,
    Linnaeus5_256,
    NotMNIST,
    NotMNISTFromHAT,
    OxfordIIITPet2,
    OxfordIIITPet37,
    SignLanguageMNIST,
    TrafficSignsFromHAT,
)

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class Combined(MTLCombinedDataset):
    r"""Combined MTL dataset from available datasets."""

    AVAILABLE_DATASETS: list[type[VisionDataset]] = [
        ArabicHandwrittenDigits,
        CIFAR10,
        CIFAR100,
        CUB2002011,
        Caltech101,
        Caltech256,
        CelebA,
        Country211,
        DTD,
        EMNISTBalanced,
        EMNISTByClass,
        EMNISTByMerge,
        EMNISTDigits,
        EMNISTLetters,
        EuroSAT,
        FER2013,
        FGVCAircraftFamily,
        FGVCAircraftManufacturer,
        FGVCAircraftVariant,
        FaceScrub10,
        FaceScrub100,
        FaceScrubFromHAT,
        FaceScrub20,
        FaceScrub50,
        FashionMNIST,
        Flowers102,
        Food101,
        GTSRB,
        KMNIST,
        KannadaMNIST,
        Linnaeus5_128,
        Linnaeus5_256,
        Linnaeus5_32,
        Linnaeus5_64,
        MNIST,
        NotMNIST,
        NotMNISTFromHAT,
        OxfordIIITPet2,
        OxfordIIITPet37,
        PCAM,
        RenderedSST2,
        SEMEION,
        SUN397,
        SVHN,
        SignLanguageMNIST,
        StanfordCars,
        TrafficSignsFromHAT,
        TinyImageNet,
        USPS,
    ]
    r"""The list of available datasets."""

    def __init__(
        self,
        datasets: dict[int, str],
        root: str | dict[int, str],
        validation_percentage: float,
        test_percentage: float,
        sampling_strategy: str = "mixed",
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    ) -> None:
        r"""
        **Args:**
        - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `dict[int, str]`, it should be a dict of task IDs and their corresponding root directories.
        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset).
        - **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset).
        - **sampling_strategy** (`str`): the sampling strategy used to construct training batches from the tasks' datasets; one of:
            - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
        - **batch_size** (`int`): The batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization and so on are not included.
        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        """
        super().__init__(
            datasets=datasets,
            root=root,
            sampling_strategy=sampling_strategy,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.test_percentage: float = test_percentage
        """The percentage to randomly split some data into test data."""
        self.validation_percentage: float = validation_percentage
        """The percentage to randomly split some training data into validation data."""

    def prepare_data(self) -> None:
        r"""Download the original datasets if they haven't been downloaded yet."""

        failed_dataset_classes = []
        for task_id in range(1, self.num_tasks + 1):
            root = self.root[task_id]
            dataset_class = self.original_dataset_python_classes[task_id]
            # torchvision datasets might have different APIs
            try:
                # collect errors and raise them at the end so one failure doesn't stop the whole download process

                if dataset_class in [
                    ArabicHandwrittenDigits,
                    KannadaMNIST,
                    SignLanguageMNIST,
                ]:
                    # these datasets have no automatic download function, so users must download them manually
                    # the following code just checks whether the dataset is already downloaded
                    dataset_class(root=root, train=True, download=False)
                    dataset_class(root=root, train=False, download=False)

                elif dataset_class in [
                    Caltech101,
                    Caltech256,
                    EuroSAT,
                    SEMEION,
                    SUN397,
                ]:
                    # dataset classes that don't have any train, val, test split
                    dataset_class(root=root, download=True)

                elif dataset_class in [
                    CIFAR10,
                    CIFAR100,
                    CUB2002011,
                    EMNISTByClass,
                    EMNISTByMerge,
                    EMNISTBalanced,
                    EMNISTLetters,
                    EMNISTDigits,
                    FaceScrub10,
                    FaceScrub20,
                    FaceScrub50,
                    FaceScrub100,
                    FaceScrubFromHAT,
                    FashionMNIST,
                    KMNIST,
                    Linnaeus5_32,
                    Linnaeus5_64,
                    Linnaeus5_128,
                    Linnaeus5_256,
                    MNIST,
                    NotMNIST,
                    NotMNISTFromHAT,
                    TrafficSignsFromHAT,
                    USPS,
                ]:
                    # dataset classes that have `train` bool argument
                    # (manual-download datasets are excluded here; they are handled in the first branch above)
                    dataset_class(root=root, train=True, download=True)
                    dataset_class(root=root, train=False, download=True)

                elif dataset_class in [
                    Food101,
                    GTSRB,
                    StanfordCars,
                    SVHN,
                ]:
                    # dataset classes that have `split` argument with 'train', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class in [Country211]:
                    # dataset classes that have `split` argument with 'train', 'valid', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="valid",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )

                elif dataset_class in [
                    DTD,
                    FGVCAircraftVariant,
                    FGVCAircraftFamily,
                    FGVCAircraftManufacturer,
                    Flowers102,
                    PCAM,
                    RenderedSST2,
                ]:
                    # dataset classes that have `split` argument with 'train', 'val', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="val",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]:
                    # dataset classes that have `split` argument with 'trainval', 'test'
                    dataset_class(
                        root=root,
                        split="trainval",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class == CelebA:
                    # special case: CelebA needs `target_type` as well
                    dataset_class(
                        root=root,
                        split="train",
                        target_type="identity",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="valid",
                        target_type="identity",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        target_type="identity",
                        download=True,
                    )
                elif dataset_class == FER2013:
                    # special case: FER2013 has no `download` argument
                    dataset_class(
                        root=root,
                        split="train",
                    )
                    dataset_class(
                        root=root,
                        split="test",
                    )
                elif dataset_class == TinyImageNet:
                    # special case: constructing TinyImageNet triggers the download
                    dataset_class(root=root)

            except RuntimeError:
                failed_dataset_classes.append(dataset_class)  # collected for the error raised below
            else:
                pylogger.debug(
                    "The original %s dataset for task %s has been downloaded to %s.",
                    dataset_class,
                    task_id,
                    root,
                )

        if failed_dataset_classes:
            raise RuntimeError(
                f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again or manually."
            )

    def train_and_val_dataset(self, task_id: int) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation dataset of task `task_id`.

        **Args:**
        - **task_id** (`int`): the task ID to get the training and validation dataset for.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `task_id`.
        """
        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]

        # torchvision datasets might have different APIs
        if original_dataset_python_class_t in [
            ArabicHandwrittenDigits,
            CIFAR10,
            CIFAR100,
            CUB2002011,
            EMNISTByClass,
            EMNISTByMerge,
            EMNISTBalanced,
            EMNISTLetters,
            EMNISTDigits,
            FaceScrub10,
            FaceScrub20,
            FaceScrub50,
            FaceScrub100,
            FaceScrubFromHAT,
            FashionMNIST,
            KannadaMNIST,
            KMNIST,
            Linnaeus5_32,
            Linnaeus5_64,
            Linnaeus5_128,
            Linnaeus5_256,
            MNIST,
            NotMNIST,
            NotMNISTFromHAT,
            SignLanguageMNIST,
            TrafficSignsFromHAT,
            USPS,
        ]:
            # dataset classes that have `train` bool argument
            dataset_train_and_val = original_dataset_python_class_t(
                root=self.root[task_id],
                train=True,
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary
            )
        elif original_dataset_python_class_t in [
            Caltech101,
            Caltech256,
            EuroSAT,
            SEMEION,
            SUN397,
        ]:
            # dataset classes that don't have a train and test split
            dataset_all = original_dataset_python_class_t(
                root=self.root[task_id],
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            dataset_train_and_val, _ = random_split(
                dataset_all,
                lengths=[
                    1 - self.test_percentage,
                    self.test_percentage,
                ],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary
            )
        elif original_dataset_python_class_t in [Country211]:
            # dataset classes that have `split` argument with 'train', 'valid', 'test'
            dataset_train = original_dataset_python_class_t(
                root=self.root[task_id],
                split="train",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )
            dataset_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="valid",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_train, dataset_val

        elif original_dataset_python_class_t in [
            DTD,
            FGVCAircraftVariant,
            FGVCAircraftFamily,
            FGVCAircraftManufacturer,
            Flowers102,
            PCAM,
            RenderedSST2,
        ]:
            # dataset classes that have `split` argument with 'train', 'val', 'test'
            dataset_train = original_dataset_python_class_t(
                root=self.root[task_id],
                split="train",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            dataset_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="val",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_train, dataset_val

        elif original_dataset_python_class_t in [
            FER2013,
            Food101,
            GTSRB,
            StanfordCars,
            SVHN,
            TinyImageNet,
        ]:
            # dataset classes that have `split` argument with 'train', 'test' (TinyImageNet uses 'val' as its held-out split)

            dataset_train_and_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="train",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary
            )

        elif original_dataset_python_class_t in [OxfordIIITPet2, OxfordIIITPet37]:
            # dataset classes that have `split` argument with 'trainval', 'test'

            dataset_train_and_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="trainval",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary
            )

        elif original_dataset_python_class_t in [CelebA]:
            # special case: CelebA needs `target_type` as well
            dataset_train = original_dataset_python_class_t(
                root=self.root[task_id],
                split="train",
                target_type="identity",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            dataset_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="valid",
                target_type="identity",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_train, dataset_val

    def test_dataset(self, task_id: int) -> Dataset:
        """Get the test dataset of task `task_id`.

        **Args:**
        - **task_id** (`int`): the task ID to get the test dataset for.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset of task `task_id`.
        """
        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]

        if original_dataset_python_class_t in [
            ArabicHandwrittenDigits,
            CIFAR10,
            CIFAR100,
            CUB2002011,
            EMNISTByClass,
            EMNISTByMerge,
            EMNISTBalanced,
            EMNISTLetters,
            EMNISTDigits,
            FaceScrub10,
            FaceScrub20,
            FaceScrub50,
            FaceScrub100,
            FaceScrubFromHAT,
            FashionMNIST,
            KannadaMNIST,
            KMNIST,
            Linnaeus5_32,
            Linnaeus5_64,
            Linnaeus5_128,
            Linnaeus5_256,
            MNIST,
            NotMNIST,
            NotMNISTFromHAT,
            SignLanguageMNIST,
            TrafficSignsFromHAT,
            USPS,
        ]:
            # dataset classes that have `train` bool argument
            dataset_test = original_dataset_python_class_t(
                root=self.root[task_id],
                train=False,
                transform=self.test_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_test

        elif original_dataset_python_class_t in [
            Country211,
            DTD,
            FER2013,
            FGVCAircraftVariant,
            FGVCAircraftFamily,
            FGVCAircraftManufacturer,
            Flowers102,
            Food101,
            GTSRB,
            OxfordIIITPet2,
            OxfordIIITPet37,
            PCAM,
            RenderedSST2,
            StanfordCars,
            SVHN,
        ]:
            # dataset classes that have `split` argument with 'test'

            dataset_test = original_dataset_python_class_t(
                root=self.root[task_id],
                split="test",
                transform=self.test_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_test

        elif original_dataset_python_class_t in [
            Caltech101,
            Caltech256,
            EuroSAT,
            SEMEION,
            SUN397,
        ]:
            # dataset classes that don't have a train and test split

            dataset_all = original_dataset_python_class_t(
                root=self.root[task_id],
                transform=self.test_transforms(task_id),  # this split serves as the test set, so apply test transforms
                target_transform=self.target_transform(task_id),
            )

            _, dataset_test = random_split(
                dataset_all,
                lengths=[1 - self.test_percentage, self.test_percentage],
                generator=torch.Generator().manual_seed(42),  # same fixed seed as in `train_and_val_dataset`, so test never overlaps train/val
            )

            return dataset_test

        elif original_dataset_python_class_t in [CelebA]:
            # special case: CelebA needs `target_type` as well
            dataset_test = original_dataset_python_class_t(
                root=self.root[task_id],
                split="test",
                target_type="identity",
                transform=self.test_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_test

        elif original_dataset_python_class_t in [TinyImageNet]:
            # special case: TinyImageNet's labeled 'val' split serves as the test set
            dataset_test = original_dataset_python_class_t(
                root=self.root[task_id],
                split="val",
                transform=self.test_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_test
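
All of the ad-hoc splits above go through `torch.utils.data.random_split` with a generator fixed to seed 42, so the same samples land in the same subset on every run and across experiments. A standalone sketch of that behavior on a toy dataset (not part of this module):

import torch
from torch.utils.data import TensorDataset, random_split

toy = TensorDataset(torch.arange(100))
# fractional lengths are supported since PyTorch 1.13; the fixed-seed generator
# makes the 90/10 split reproducible across runs
train, val = random_split(toy, lengths=[0.9, 0.1], generator=torch.Generator().manual_seed(42))
assert len(train) == 90 and len(val) == 10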
class Combined(clarena.mtl_datasets.base.MTLCombinedDataset):

Combined MTL dataset from available datasets.

Combined( datasets: dict[int, str], root: str | dict[int, str], validation_percentage: float, test_percentage: float, sampling_strategy: str = 'mixed', batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)

Args:

  • datasets (dict[int, str]): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
  • root (str | dict[int, str]): the root directory where the original data files for constructing the MTL dataset physically live. If dict[int, str], it should be a dict of task IDs and their corresponding root directories.
  • validation_percentage (float): the fraction of training data randomly split off as validation data (used only when the dataset does not provide a validation split).
  • test_percentage (float): the fraction of the entire data randomly split off as test data (used only when the dataset does not provide a test split).
  • sampling_strategy (str): the sampling strategy used to construct training batches from each task's dataset; one of:
    • 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
  • batch_size (int): the batch size for the train, validation, and test dataloaders.
  • num_workers (int): the number of workers for the dataloaders.
  • custom_transforms (transform | transforms.Compose | None | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, the same number of channels is repeated for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None | dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
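For instance, a two-task setup pairing MNIST with CIFAR-10 might be configured as below. This is a minimal sketch: the exact string format expected for the dataset class paths is defined by MTLCombinedDataset, and the class paths and root directories shown here are illustrative assumptions, not values prescribed by this module.

from clarena.mtl_datasets.combined import Combined

# Hypothetical two-task configuration: task IDs map to dataset class paths.
mtl_dataset = Combined(
    datasets={
        1: "torchvision.datasets.MNIST",    # illustrative class path
        2: "torchvision.datasets.CIFAR10",  # illustrative class path
    },
    root={1: "data/mnist", 2: "data/cifar10"},
    validation_percentage=0.1,
    test_percentage=0.2,
    batch_size=64,
    repeat_channels={1: 3, 2: None},  # repeat MNIST's single channel to 3
    resize=(32, 32),                  # bring both tasks to a common image size
)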
AVAILABLE_DATASETS: list[VisionDataset] = [ArabicHandwrittenDigits, CIFAR10, CIFAR100, CUB2002011, Caltech101, Caltech256, CelebA, Country211, DTD, EMNISTBalanced, EMNISTByClass, EMNISTByMerge, EMNISTDigits, EMNISTLetters, EuroSAT, FER2013, FGVCAircraftFamily, FGVCAircraftManufacturer, FGVCAircraftVariant, FaceScrub10, FaceScrub100, FaceScrubFromHAT, FaceScrub20, FaceScrub50, FashionMNIST, Flowers102, Food101, GTSRB, KMNIST, KannadaMNIST, Linnaeus5_128, Linnaeus5_256, Linnaeus5_32, Linnaeus5_64, MNIST, NotMNIST, NotMNISTFromHAT, OxfordIIITPet2, OxfordIIITPet37, PCAM, RenderedSST2, SEMEION, SUN397, SVHN, SignLanguageMNIST, StanfordCars, TrafficSignsFromHAT, TinyImageNet, USPS]

The list of available datasets.

test_percentage: float

The fraction of the data randomly split off as test data.

validation_percentage: float

The fraction of training data randomly split off as validation data.
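Both percentages are used as fractional lengths for torch.utils.data.random_split, which accepts fractions summing to 1, together with a seeded generator so the split is reproducible. A standalone sketch of that behavior on a toy dataset:

import torch
from torch.utils.data import TensorDataset, random_split

toy = TensorDataset(torch.arange(100))
validation_percentage = 0.1

# Fractional lengths: 90% train, 10% validation. The generator seed is fixed
# (as in the class above) so the split is identical across runs.
train_ds, val_ds = random_split(
    toy,
    lengths=[1 - validation_percentage, validation_percentage],
    generator=torch.Generator().manual_seed(42),
)
assert len(train_ds) == 90 and len(val_ds) == 10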

def prepare_data(self) -> None:
    def prepare_data(self) -> None:
        r"""Download the original datasets if they haven't been downloaded yet."""

        # collect errors and raise them together at the end, so one failure
        # doesn't stop the whole download process
        failed_dataset_classes = []
        for task_id in range(1, self.num_tasks + 1):
            root = self.root[task_id]
            dataset_class = self.original_dataset_python_classes[task_id]
            # torchvision datasets might have different APIs
            try:
                if dataset_class in [
                    ArabicHandwrittenDigits,
                    KannadaMNIST,
                    SignLanguageMNIST,
                ]:
                    # these datasets have no automatic download function; users must download them manually
                    # the following code just checks whether the dataset is already downloaded
                    dataset_class(root=root, train=True, download=False)
                    dataset_class(root=root, train=False, download=False)

                elif dataset_class in [
                    Caltech101,
                    Caltech256,
                    EuroSAT,
                    SEMEION,
                    SUN397,
                ]:
                    # dataset classes that don't have any train, val, test split
                    dataset_class(root=root, download=True)

                elif dataset_class in [
                    CIFAR10,
                    CIFAR100,
                    CUB2002011,
                    EMNISTByClass,
                    EMNISTByMerge,
                    EMNISTBalanced,
                    EMNISTLetters,
                    EMNISTDigits,
                    FaceScrub10,
                    FaceScrub20,
                    FaceScrub50,
                    FaceScrub100,
                    FaceScrubFromHAT,
                    FashionMNIST,
                    KMNIST,
                    Linnaeus5_32,
                    Linnaeus5_64,
                    Linnaeus5_128,
                    Linnaeus5_256,
                    MNIST,
                    NotMNIST,
                    NotMNISTFromHAT,
                    TrafficSignsFromHAT,
                    USPS,
                ]:
                    # dataset classes that have a `train` bool argument
                    # (the manual-download classes also take `train` but are handled in the first branch above)
                    dataset_class(root=root, train=True, download=True)
                    dataset_class(root=root, train=False, download=True)

                elif dataset_class in [
                    Food101,
                    GTSRB,
                    StanfordCars,
                    SVHN,
                ]:
                    # dataset classes that have a `split` argument with 'train', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class in [Country211]:
                    # dataset classes that have a `split` argument with 'train', 'valid', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="valid",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )

                elif dataset_class in [
                    DTD,
                    FGVCAircraftVariant,
                    FGVCAircraftFamily,
                    FGVCAircraftManufacturer,
                    Flowers102,
                    PCAM,
                    RenderedSST2,
                ]:
                    # dataset classes that have a `split` argument with 'train', 'val', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="val",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]:
                    # dataset classes that have a `split` argument with 'trainval', 'test'
                    dataset_class(
                        root=root,
                        split="trainval",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class == CelebA:
                    # special case: CelebA requires a `target_type`
                    dataset_class(
                        root=root,
                        split="train",
                        target_type="identity",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="valid",
                        target_type="identity",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        target_type="identity",
                        download=True,
                    )
                elif dataset_class == FER2013:
                    # special case: FER2013 has no `download` argument; this just checks the files are present
                    dataset_class(
                        root=root,
                        split="train",
                    )
                    dataset_class(
                        root=root,
                        split="test",
                    )
                elif dataset_class == TinyImageNet:
                    # special case: construction handles the download itself
                    dataset_class(root=root)

            except RuntimeError:
                failed_dataset_classes.append(dataset_class)  # save for the final prompt
            else:
                pylogger.debug(
                    "The original %s dataset for task %s has been downloaded to %s.",
                    dataset_class,
                    task_id,
                    root,
                )

        if failed_dataset_classes:
            raise RuntimeError(
                f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again, or download them manually."
            )

Download the original datasets if they haven't been downloaded yet.
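Failures are collected per task and re-raised together at the end, so one broken download does not block the rest. A hedged usage sketch (mtl_dataset is the illustrative instance from the earlier example):

try:
    mtl_dataset.prepare_data()
except RuntimeError as err:
    # The message lists every dataset class that failed, e.g. the
    # manual-download datasets (ArabicHandwrittenDigits, KannadaMNIST,
    # SignLanguageMNIST) if their files are not already under `root`.
    print(f"Some datasets could not be prepared: {err}")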

def train_and_val_dataset(self, task_id: int) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
    def train_and_val_dataset(self, task_id: int) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation dataset of task `task_id`.

        **Args:**
        - **task_id** (`int`): the task ID to get the training and validation dataset for.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `task_id`.
        """
        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]

        # torchvision datasets might have different APIs
        if original_dataset_python_class_t in [
            ArabicHandwrittenDigits,
            CIFAR10,
            CIFAR100,
            CUB2002011,
            EMNISTByClass,
            EMNISTByMerge,
            EMNISTBalanced,
            EMNISTLetters,
            EMNISTDigits,
            FaceScrub10,
            FaceScrub20,
            FaceScrub50,
            FaceScrub100,
            FaceScrubFromHAT,
            FashionMNIST,
            KannadaMNIST,
            KMNIST,
            Linnaeus5_32,
            Linnaeus5_64,
            Linnaeus5_128,
            Linnaeus5_256,
            MNIST,
            NotMNIST,
            NotMNISTFromHAT,
            SignLanguageMNIST,
            TrafficSignsFromHAT,
            USPS,
        ]:
            # dataset classes that have a `train` bool argument
            dataset_train_and_val = original_dataset_python_class_t(
                root=self.root[task_id],
                train=True,
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary across experiments
            )
        elif original_dataset_python_class_t in [
            Caltech101,
            Caltech256,
            EuroSAT,
            SEMEION,
            SUN397,
        ]:
            # dataset classes that don't have train and test splits
            dataset_all = original_dataset_python_class_t(
                root=self.root[task_id],
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            dataset_train_and_val, _ = random_split(
                dataset_all,
                lengths=[
                    1 - self.test_percentage,
                    self.test_percentage,
                ],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary across experiments
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary across experiments
            )
        elif original_dataset_python_class_t in [Country211]:
            # dataset classes that have a `split` argument with 'train', 'valid', 'test'
            dataset_train = original_dataset_python_class_t(
                root=self.root[task_id],
                split="train",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )
            dataset_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="valid",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_train, dataset_val

        elif original_dataset_python_class_t in [
            DTD,
            FGVCAircraftVariant,
            FGVCAircraftFamily,
            FGVCAircraftManufacturer,
            Flowers102,
            PCAM,
            RenderedSST2,
        ]:
            # dataset classes that have a `split` argument with 'train', 'val', 'test'
            dataset_train = original_dataset_python_class_t(
                root=self.root[task_id],
                split="train",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            dataset_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="val",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_train, dataset_val

        elif original_dataset_python_class_t in [
            FER2013,
            Food101,
            GTSRB,
            StanfordCars,
            SVHN,
            TinyImageNet,
        ]:
            # dataset classes that have a `split` argument with 'train', 'test'
            dataset_train_and_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="train",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary across experiments
            )

        elif original_dataset_python_class_t in [OxfordIIITPet2, OxfordIIITPet37]:
            # dataset classes that have a `split` argument with 'trainval', 'test'
            dataset_train_and_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="trainval",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # the seed must stay fixed so the splits are identical across experiments; don't tie it to the global seed, which may vary across experiments
            )

        elif original_dataset_python_class_t in [CelebA]:
            # special case: CelebA requires a `target_type`
            dataset_train = original_dataset_python_class_t(
                root=self.root[task_id],
                split="train",
                target_type="identity",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            dataset_val = original_dataset_python_class_t(
                root=self.root[task_id],
                split="valid",
                target_type="identity",
                transform=self.train_and_val_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_train, dataset_val

Get the training and validation dataset of task task_id.

Args:

  • task_id (int): the task ID to get the training and validation dataset for.

Returns:

  • train_and_val_dataset (tuple[Dataset, Dataset]): the train and validation dataset of task task_id.
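The returned pair plugs directly into standard PyTorch dataloaders. A minimal sketch, assuming mtl_dataset from the earlier illustrative example has been prepared and task 1 exists:

from torch.utils.data import DataLoader

train_ds, val_ds = mtl_dataset.train_and_val_dataset(task_id=1)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
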
def test_dataset(self, task_id: int) -> torch.utils.data.dataset.Dataset:
    def test_dataset(self, task_id: int) -> Dataset:
        r"""Get the test dataset of task `task_id`.

        **Args:**
        - **task_id** (`int`): the task ID to get the test dataset for.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset of task `task_id`.
        """
        original_dataset_python_class_t = self.original_dataset_python_classes[task_id]

        if original_dataset_python_class_t in [
            ArabicHandwrittenDigits,
            CIFAR10,
            CIFAR100,
            CUB2002011,
            EMNISTByClass,
            EMNISTByMerge,
            EMNISTBalanced,
            EMNISTLetters,
            EMNISTDigits,
            FaceScrub10,
            FaceScrub20,
            FaceScrub50,
            FaceScrub100,
            FaceScrubFromHAT,
            FashionMNIST,
            KannadaMNIST,
            KMNIST,
            Linnaeus5_32,
            Linnaeus5_64,
            Linnaeus5_128,
            Linnaeus5_256,
            MNIST,
            NotMNIST,
            NotMNISTFromHAT,
            SignLanguageMNIST,
            TrafficSignsFromHAT,
            USPS,
        ]:
            # dataset classes that have a `train` bool argument
            dataset_test = original_dataset_python_class_t(
                root=self.root[task_id],
                train=False,
                transform=self.test_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_test

        elif original_dataset_python_class_t in [
            Country211,
            DTD,
            FER2013,
            FGVCAircraftVariant,
            FGVCAircraftFamily,
            FGVCAircraftManufacturer,
            Flowers102,
            Food101,
            GTSRB,
            OxfordIIITPet2,
            OxfordIIITPet37,
            PCAM,
            RenderedSST2,
            StanfordCars,
            SVHN,
        ]:
            # dataset classes that have a `split` argument with 'test'
            dataset_test = original_dataset_python_class_t(
                root=self.root[task_id],
                split="test",
                transform=self.test_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_test

        elif original_dataset_python_class_t in [
            Caltech101,
            Caltech256,
            EuroSAT,
            SEMEION,
            SUN397,
        ]:
            # dataset classes that don't have train and test splits
            dataset_all = original_dataset_python_class_t(
                root=self.root[task_id],
                transform=self.test_transforms(task_id),  # test transforms, as this subset serves as the test set
                target_transform=self.target_transform(task_id),
            )

            _, dataset_test = random_split(
                dataset_all,
                lengths=[1 - self.test_percentage, self.test_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # same fixed seed as in `train_and_val_dataset`, so this test subset is exactly the complement of the train/val subset
            )

            return dataset_test

        elif original_dataset_python_class_t in [CelebA]:
            # special case: CelebA requires a `target_type`
            dataset_test = original_dataset_python_class_t(
                root=self.root[task_id],
                split="test",
                target_type="identity",
                transform=self.test_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_test

        elif original_dataset_python_class_t in [TinyImageNet]:
            # special case: TinyImageNet's test split is unlabeled, so its 'val' split serves as the test set
            dataset_test = original_dataset_python_class_t(
                root=self.root[task_id],
                split="val",
                transform=self.test_transforms(task_id),
                target_transform=self.target_transform(task_id),
            )

            return dataset_test

Get the test dataset of task task_id.

Args:

  • task_id (int): the task ID to get the test dataset for.

Returns:

  • test_dataset (Dataset): the test dataset of task task_id.
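For datasets without an official train/test split (Caltech101, Caltech256, EuroSAT, SEMEION, SUN397), train_and_val_dataset and test_dataset each re-split the full dataset independently. This is consistent only because random_split is deterministic for a fixed generator seed: two calls with the same seed and lengths yield the same index partition, so the test subset carved out here is exactly the complement of the train/val subset. A standalone sketch of that property on a toy dataset:

import torch
from torch.utils.data import TensorDataset, random_split

toy = TensorDataset(torch.arange(100))

def split(seed: int = 42):
    # Mirrors the pattern used above: fixed seed, fractional lengths.
    return random_split(
        toy, lengths=[0.8, 0.2], generator=torch.Generator().manual_seed(seed)
    )

first_train, first_test = split()
second_train, second_test = split()

# Same seed => identical index partitions across independent calls.
assert first_train.indices == second_train.indices
assert first_test.indices == second_test.indices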