# clarena.cl_datasets.combined
#
# The submodule in `cl_datasets` for combined datasets.

  1r"""
  2The submodule in `cl_datasets` for combined datasets.
  3"""
  4
  5__all__ = ["Combined"]
  6
  7import logging
  8from typing import Callable
  9
 10import torch
 11from tinyimagenet import TinyImageNet
 12from torch.utils.data import Dataset, random_split
 13from torchvision import transforms
 14from torchvision.datasets import (
 15    CIFAR10,
 16    CIFAR100,
 17    DTD,
 18    FER2013,
 19    GTSRB,
 20    KMNIST,
 21    MNIST,
 22    PCAM,
 23    SEMEION,
 24    SUN397,
 25    SVHN,
 26    USPS,
 27    Caltech101,
 28    Caltech256,
 29    CelebA,
 30    Country211,
 31    EuroSAT,
 32    FashionMNIST,
 33    Flowers102,
 34    Food101,
 35    RenderedSST2,
 36    StanfordCars,
 37)
 38from torchvision.datasets.vision import VisionDataset
 39
 40from clarena.cl_datasets import CLCombinedDataset
 41from clarena.stl_datasets.raw import (
 42    CUB2002011,
 43    ArabicHandwrittenDigits,
 44    EMNISTBalanced,
 45    EMNISTByClass,
 46    EMNISTByMerge,
 47    EMNISTDigits,
 48    EMNISTLetters,
 49    FaceScrub10,
 50    FaceScrub20,
 51    FaceScrub50,
 52    FaceScrub100,
 53    FaceScrubFromHAT,
 54    FGVCAircraftFamily,
 55    FGVCAircraftManufacturer,
 56    FGVCAircraftVariant,
 57    KannadaMNIST,
 58    Linnaeus5_32,
 59    Linnaeus5_64,
 60    Linnaeus5_128,
 61    Linnaeus5_256,
 62    NotMNIST,
 63    NotMNISTFromHAT,
 64    OxfordIIITPet2,
 65    OxfordIIITPet37,
 66    SignLanguageMNIST,
 67    TrafficSignsFromHAT,
 68)
 69
 70# always get logger for built-in logging in each module
 71pylogger = logging.getLogger(__name__)
 72
 73
 74class Combined(CLCombinedDataset):
 75    r"""Combined CL dataset from available datasets."""
 76
 77    AVAILABLE_DATASETS: list[VisionDataset] = [
 78        ArabicHandwrittenDigits,
 79        CIFAR10,
 80        CIFAR100,
 81        CUB2002011,
 82        Caltech101,
 83        Caltech256,
 84        CelebA,
 85        Country211,
 86        DTD,
 87        EMNISTBalanced,
 88        EMNISTByClass,
 89        EMNISTByMerge,
 90        EMNISTDigits,
 91        EMNISTLetters,
 92        EuroSAT,
 93        FER2013,
 94        FGVCAircraftFamily,
 95        FGVCAircraftManufacturer,
 96        FGVCAircraftVariant,
 97        FaceScrub10,
 98        FaceScrub100,
 99        FaceScrubFromHAT,
100        FaceScrub20,
101        FaceScrub50,
102        FashionMNIST,
103        Flowers102,
104        Food101,
105        GTSRB,
106        KMNIST,
107        KannadaMNIST,
108        Linnaeus5_128,
109        Linnaeus5_256,
110        Linnaeus5_32,
111        Linnaeus5_64,
112        MNIST,
113        NotMNIST,
114        NotMNISTFromHAT,
115        OxfordIIITPet2,
116        OxfordIIITPet37,
117        PCAM,
118        RenderedSST2,
119        SEMEION,
120        SUN397,
121        SVHN,
122        SignLanguageMNIST,
123        StanfordCars,
124        TrafficSignsFromHAT,
125        TinyImageNet,
126        USPS,
127    ]
128    r"""The list of available datasets."""
129
130    def __init__(
131        self,
132        datasets: list[str],
133        root: list[str],
134        validation_percentage: float,
135        test_percentage: float,
136        batch_size: int | list[int] = 1,
137        num_workers: int | list[int] = 0,
138        custom_transforms: (
139            Callable
140            | transforms.Compose
141            | None
142            | list[Callable | transforms.Compose | None]
143        ) = None,
144        repeat_channels: int | None | list[int | None] = None,
145        to_tensor: bool | list[bool] = True,
146        resize: tuple[int, int] | None | list[tuple[int, int] | None] = None,
147    ) -> None:
148        r"""Initialize the Combined Torchvision dataset object providing the root where data files live.
149
150        **Args:**
151        - **datasets** (`list[str]`): the list of dataset class paths for each task. Each element in the list must be a string referring to a valid PyTorch Dataset class. It needs to be one in `self.AVAILABLE_DATASETS`.
152        - **root** (`list[str]`): the list of root directory where the original data files for constructing the CL dataset physically live.
153        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset).
154        - **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset).
155        - **batch_size** (`int` | `list[int]`): The batch size in train, val, test dataloader. If `list[str]`, it should be a list of integers, each integer is the batch size for each task.
156        - **num_workers** (`int` | `list[int]`): the number of workers for dataloaders. If `list[str]`, it should be a list of integers, each integer is the num of workers for each task.
157        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or list of them): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize, permute and so on are not included. If it is a list, each item is the custom transforms for each task.
158        - **repeat_channels** (`int` | `None` | list of them): the number of channels to repeat for each task. Default is None, which means no repeat. If not None, it should be an integer. If it is a list, each item is the number of channels to repeat for each task.
159        - **to_tensor** (`bool` | `list[bool]`): whether to include `ToTensor()` transform. Default is True.
160        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. If it is a list, each item is the size to resize for each task.
161        """
162        super().__init__(
163            datasets=datasets,
164            root=root,
165            batch_size=batch_size,
166            num_workers=num_workers,
167            custom_transforms=custom_transforms,
168            repeat_channels=repeat_channels,
169            to_tensor=to_tensor,
170            resize=resize,
171        )
172
173        self.test_percentage: float = test_percentage
174        """The percentage to randomly split some data into test data."""
175        self.validation_percentage: float = validation_percentage
176        """The percentage to randomly split some training data into validation data."""
177
178    def prepare_data(self) -> None:
179        r"""Download the original datasets if haven't."""
180
181        if self.task_id != 1:
182            return  # download all original datasets only at the beginning of first task
183
184        failed_dataset_classes = []
185        for task_id in range(1, self.num_tasks + 1):
186            root = self.root[task_id]
187            dataset_class = self.original_dataset_python_classes[task_id]
188            # torchvision datasets might have different APIs
189            try:
190                # collect the error and raise it at the end to avoid stopping the whole download process
191
192                if dataset_class in [
193                    ArabicHandwrittenDigits,
194                    KannadaMNIST,
195                    SignLanguageMNIST,
196                ]:
197                    # these datasets have no automatic download function, we require users to download them manually
198                    # the following code is just to check if the dataset is already downloaded
199                    dataset_class(root=root, train=True, download=False)
200                    dataset_class(root=root, train=False, download=False)
201
202                elif dataset_class in [
203                    Caltech101,
204                    Caltech256,
205                    EuroSAT,
206                    SEMEION,
207                    SUN397,
208                ]:
209                    # dataset classes that don't have any train, val, test split
210                    dataset_class(root=root, download=True)
211
212                elif dataset_class in [
213                    ArabicHandwrittenDigits,
214                    CIFAR10,
215                    CIFAR100,
216                    CUB2002011,
217                    EMNISTByClass,
218                    EMNISTByMerge,
219                    EMNISTBalanced,
220                    EMNISTLetters,
221                    EMNISTDigits,
222                    FaceScrub10,
223                    FaceScrub20,
224                    FaceScrub50,
225                    FaceScrub100,
226                    FaceScrubFromHAT,
227                    FashionMNIST,
228                    KannadaMNIST,
229                    KMNIST,
230                    Linnaeus5_32,
231                    Linnaeus5_64,
232                    Linnaeus5_128,
233                    Linnaeus5_256,
234                    MNIST,
235                    NotMNIST,
236                    NotMNISTFromHAT,
237                    SignLanguageMNIST,
238                    TrafficSignsFromHAT,
239                    USPS,
240                ]:
241                    # dataset classes that have `train` bool argument
242                    dataset_class(root=root, train=True, download=True)
243                    dataset_class(root=root, train=False, download=True)
244
245                elif dataset_class in [
246                    Food101,
247                    GTSRB,
248                    StanfordCars,
249                    SVHN,
250                ]:
251                    # dataset classes that have `split` argument with 'train', 'test'
252                    dataset_class(
253                        root=root,
254                        split="train",
255                        download=True,
256                    )
257                    dataset_class(
258                        root=root,
259                        split="test",
260                        download=True,
261                    )
262                elif dataset_class in [Country211]:
263                    # dataset classes that have `split` argument with 'train', 'valid', 'test'
264                    dataset_class(
265                        root=root,
266                        split="train",
267                        download=True,
268                    )
269                    dataset_class(
270                        root=root,
271                        split="valid",
272                        download=True,
273                    )
274                    dataset_class(
275                        root=root,
276                        split="test",
277                        download=True,
278                    )
279
280                elif dataset_class in [
281                    DTD,
282                    FGVCAircraftVariant,
283                    FGVCAircraftFamily,
284                    FGVCAircraftManufacturer,
285                    Flowers102,
286                    PCAM,
287                    RenderedSST2,
288                ]:
289                    # dataset classes that have `split` argument with 'train', 'val', 'test'
290                    dataset_class(
291                        root=root,
292                        split="train",
293                        download=True,
294                    )
295                    dataset_class(
296                        root=root,
297                        split="val",
298                        download=True,
299                    )
300                    dataset_class(
301                        root=root,
302                        split="test",
303                        download=True,
304                    )
305                elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]:
306                    # dataset classes that have `split` argument with 'trainval', 'test'
307                    dataset_class(
308                        root=root,
309                        split="trainval",
310                        download=True,
311                    )
312                    dataset_class(
313                        root=root,
314                        split="test",
315                        download=True,
316                    )
317                elif dataset_class == CelebA:
318                    # special case
319                    dataset_class(
320                        root=root,
321                        split="train",
322                        target_type="identity",
323                        download=True,
324                    )
325                    dataset_class(
326                        root=root,
327                        split="valid",
328                        target_type="identity",
329                        download=True,
330                    )
331                    dataset_class(
332                        root=root,
333                        split="test",
334                        target_type="identity",
335                        download=True,
336                    )
337                elif dataset_class == FER2013:
338                    # special case
339                    dataset_class(
340                        root=root,
341                        split="train",
342                    )
343                    dataset_class(
344                        root=root,
345                        split="test",
346                    )
347                elif dataset_class == TinyImageNet:
348                    # special case
349                    dataset_class(root=root)
350
351            except RuntimeError:
352                failed_dataset_classes.append(dataset_class)  # save for later prompt
353            else:
354                pylogger.debug(
355                    "The original %s dataset for task %s has been downloaded to %s.",
356                    dataset_class,
357                    task_id,
358                    root,
359                )
360
361        if failed_dataset_classes:
362            raise RuntimeError(
363                f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again or manually."
364            )
365
366    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
367        """Get the training and validation dataset of task `self.task_id`.
368
369        **Returns:**
370        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.
371        """
372        # torchvision datasets might have different APIs
373        if self.original_dataset_python_class_t in [
374            ArabicHandwrittenDigits,
375            CIFAR10,
376            CIFAR100,
377            CUB2002011,
378            EMNISTByClass,
379            EMNISTByMerge,
380            EMNISTBalanced,
381            EMNISTLetters,
382            EMNISTDigits,
383            FaceScrub10,
384            FaceScrub20,
385            FaceScrub50,
386            FaceScrub100,
387            FaceScrubFromHAT,
388            FashionMNIST,
389            KannadaMNIST,
390            KMNIST,
391            Linnaeus5_32,
392            Linnaeus5_64,
393            Linnaeus5_128,
394            Linnaeus5_256,
395            MNIST,
396            NotMNIST,
397            NotMNISTFromHAT,
398            SignLanguageMNIST,
399            TrafficSignsFromHAT,
400            USPS,
401        ]:
402            # dataset classes that have `train` bool argument
403            dataset_train_and_val = self.original_dataset_python_class_t(
404                root=self.root_t,
405                train=True,
406                transform=self.train_and_val_transforms(),
407                target_transform=self.target_transform(),
408            )
409
410            return random_split(
411                dataset_train_and_val,
412                lengths=[1 - self.validation_percentage, self.validation_percentage],
413                generator=torch.Generator().manual_seed(
414                    42
415                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
416            )
417        elif self.original_dataset_python_class_t in [
418            Caltech101,
419            Caltech256,
420            EuroSAT,
421            SEMEION,
422            SUN397,
423        ]:
424            # dataset classes that don't have train and test splt
425            dataset_all = self.original_dataset_python_class_t(
426                root=self.root_t,
427                transform=self.train_and_val_transforms(),
428                target_transform=self.target_transform(),
429            )
430
431            dataset_train_and_val, _ = random_split(
432                dataset_all,
433                lengths=[
434                    1 - self.test_percentage,
435                    self.test_percentage,
436                ],
437                generator=torch.Generator().manual_seed(
438                    42
439                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
440            )
441
442            return random_split(
443                dataset_train_and_val,
444                lengths=[1 - self.validation_percentage, self.validation_percentage],
445                generator=torch.Generator().manual_seed(
446                    42
447                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
448            )
449        elif self.original_dataset_python_class_t in [Country211]:
450            # dataset classes that have `split` argument with 'train', 'valid', 'test'
451            dataset_train = self.original_dataset_python_class_t(
452                root=self.root_t,
453                split="train",
454                transform=self.train_and_val_transforms(),
455                target_transform=self.target_transform(),
456            )
457            dataset_val = self.original_dataset_python_class_t(
458                root=self.root_t,
459                split="valid",
460                transform=self.train_and_val_transforms(),
461                target_transform=self.target_transform(),
462            )
463
464            return dataset_train, dataset_val
465
466        elif self.original_dataset_python_class_t in [
467            DTD,
468            FGVCAircraftVariant,
469            FGVCAircraftFamily,
470            FGVCAircraftManufacturer,
471            Flowers102,
472            PCAM,
473            RenderedSST2,
474        ]:
475            # dataset classes that have `split` argument with 'train', 'val', 'test'
476            dataset_train = self.original_dataset_python_class_t(
477                root=self.root_t,
478                split="train",
479                transform=self.train_and_val_transforms(),
480                target_transform=self.target_transform(),
481            )
482
483            dataset_val = self.original_dataset_python_class_t(
484                root=self.root_t,
485                split="val",
486                transform=self.train_and_val_transforms(),
487                target_transform=self.target_transform(),
488            )
489
490            return dataset_train, dataset_val
491
492        elif self.original_dataset_python_class_t in [
493            FER2013,
494            Food101,
495            GTSRB,
496            StanfordCars,
497            SVHN,
498            TinyImageNet,
499        ]:
500            # dataset classes that have `split` argument with 'train', 'test'
501
502            dataset_train_and_val = self.original_dataset_python_class_t(
503                root=self.root_t,
504                split="train",
505                transform=self.train_and_val_transforms(),
506                target_transform=self.target_transform(),
507            )
508
509            return random_split(
510                dataset_train_and_val,
511                lengths=[1 - self.validation_percentage, self.validation_percentage],
512                generator=torch.Generator().manual_seed(
513                    42
514                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
515            )
516
517        elif self.original_dataset_python_class_t in [OxfordIIITPet2, OxfordIIITPet37]:
518            # dataset classes that have `split` argument with 'trainval', 'test'
519
520            dataset_train_and_val = self.original_dataset_python_class_t(
521                root=self.root_t,
522                split="trainval",
523                transform=self.train_and_val_transforms(),
524                target_transform=self.target_transform(),
525            )
526
527            return random_split(
528                dataset_train_and_val,
529                lengths=[1 - self.validation_percentage, self.validation_percentage],
530                generator=torch.Generator().manual_seed(
531                    42
532                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
533            )
534
535        elif self.original_dataset_python_class_t in [CelebA]:
536            # special case
537            dataset_train = self.original_dataset_python_class_t(
538                root=self.root_t,
539                split="train",
540                target_type="identity",
541                transform=self.train_and_val_transforms(),
542                target_transform=self.target_transform(),
543            )
544
545            dataset_val = self.original_dataset_python_class_t(
546                root=self.root_t,
547                split="valid",
548                target_type="identity",
549                transform=self.train_and_val_transforms(),
550                target_transform=self.target_transform(),
551            )
552
553            return dataset_train, dataset_val
554
555    def test_dataset(self) -> Dataset:
556        r"""Get the test dataset of task `self.task_id`.
557
558        **Returns:**
559        - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`.
560        """
561        # torchvision datasets might have different APIs
562
563        if self.original_dataset_python_class_t in [
564            ArabicHandwrittenDigits,
565            CIFAR10,
566            CIFAR100,
567            CUB2002011,
568            EMNISTByClass,
569            EMNISTByMerge,
570            EMNISTBalanced,
571            EMNISTLetters,
572            EMNISTDigits,
573            FaceScrub10,
574            FaceScrub20,
575            FaceScrub50,
576            FaceScrub100,
577            FaceScrubFromHAT,
578            FashionMNIST,
579            KannadaMNIST,
580            KMNIST,
581            Linnaeus5_32,
582            Linnaeus5_64,
583            Linnaeus5_128,
584            Linnaeus5_256,
585            MNIST,
586            NotMNIST,
587            NotMNISTFromHAT,
588            SignLanguageMNIST,
589            TrafficSignsFromHAT,
590            USPS,
591        ]:
592            # dataset classes that have `train` bool argument
593            dataset_test = self.original_dataset_python_class_t(
594                root=self.root_t,
595                train=False,
596                transform=self.test_transforms(),
597                target_transform=self.target_transform(),
598            )
599
600            return dataset_test
601
602        elif self.original_dataset_python_class_t in [
603            Country211,
604            DTD,
605            FER2013,
606            FGVCAircraftVariant,
607            FGVCAircraftFamily,
608            FGVCAircraftManufacturer,
609            Flowers102,
610            Food101,
611            GTSRB,
612            OxfordIIITPet2,
613            OxfordIIITPet37,
614            PCAM,
615            RenderedSST2,
616            StanfordCars,
617            SVHN,
618        ]:
619            # dataset classes that have `split` argument with 'test'
620
621            dataset_test = self.original_dataset_python_class_t(
622                root=self.root_t,
623                split="test",
624                transform=self.test_transforms(),
625                target_transform=self.target_transform(),
626            )
627
628            return dataset_test
629
630        elif self.original_dataset_python_class_t in [
631            Caltech101,
632            Caltech256,
633            EuroSAT,
634            SEMEION,
635            SUN397,
636        ]:
637            # dataset classes that don't have train and test splt
638
639            dataset_all = self.original_dataset_python_class_t(
640                root=self.root_t,
641                transform=self.train_and_val_transforms(),
642                target_transform=self.target_transform(),
643            )
644
645            _, dataset_test = random_split(
646                dataset_all,
647                lengths=[1 - self.test_percentage, self.test_percentage],
648                generator=torch.Generator().manual_seed(42),
649            )
650
651            return dataset_test
652
653        elif self.original_dataset_python_class_t in [CelebA]:
654            # special case
655            dataset_test = self.original_dataset_python_class_t(
656                root=self.root_t,
657                split="test",
658                target_type="identity",
659                transform=self.test_transforms(),
660                target_transform=self.target_transform(),
661            )
662
663            return dataset_test
664
665        elif self.original_dataset_python_class_t in [TinyImageNet]:
666            # special case
667            dataset_test = self.original_dataset_python_class_t(
668                root=self.root_t,
669                split="val",
670                transform=self.test_transforms(),
671                target_transform=self.target_transform(),
672            )
673
674            return dataset_test
class Combined(CLCombinedDataset):
    r"""Combined CL dataset constructed from the available single datasets.

    Each task in the continual learning sequence draws its data from one of
    the dataset classes listed in `AVAILABLE_DATASETS`.
    """

    # NOTE: the elements are dataset *classes* (not instances), hence
    # `list[type[VisionDataset]]` rather than `list[VisionDataset]`.
    AVAILABLE_DATASETS: list[type[VisionDataset]] = [
        ArabicHandwrittenDigits,
        CIFAR10,
        CIFAR100,
        CUB2002011,
        Caltech101,
        Caltech256,
        CelebA,
        Country211,
        DTD,
        EMNISTBalanced,
        EMNISTByClass,
        EMNISTByMerge,
        EMNISTDigits,
        EMNISTLetters,
        EuroSAT,
        FER2013,
        FGVCAircraftFamily,
        FGVCAircraftManufacturer,
        FGVCAircraftVariant,
        FaceScrub10,
        FaceScrub100,
        FaceScrubFromHAT,
        FaceScrub20,
        FaceScrub50,
        FashionMNIST,
        Flowers102,
        Food101,
        GTSRB,
        KMNIST,
        KannadaMNIST,
        Linnaeus5_128,
        Linnaeus5_256,
        Linnaeus5_32,
        Linnaeus5_64,
        MNIST,
        NotMNIST,
        NotMNISTFromHAT,
        OxfordIIITPet2,
        OxfordIIITPet37,
        PCAM,
        RenderedSST2,
        SEMEION,
        SUN397,
        SVHN,
        SignLanguageMNIST,
        StanfordCars,
        TrafficSignsFromHAT,
        TinyImageNet,
        USPS,
    ]
    r"""The list of available dataset classes."""

131    def __init__(
132        self,
133        datasets: list[str],
134        root: list[str],
135        validation_percentage: float,
136        test_percentage: float,
137        batch_size: int | list[int] = 1,
138        num_workers: int | list[int] = 0,
139        custom_transforms: (
140            Callable
141            | transforms.Compose
142            | None
143            | list[Callable | transforms.Compose | None]
144        ) = None,
145        repeat_channels: int | None | list[int | None] = None,
146        to_tensor: bool | list[bool] = True,
147        resize: tuple[int, int] | None | list[tuple[int, int] | None] = None,
148    ) -> None:
149        r"""Initialize the Combined Torchvision dataset object providing the root where data files live.
150
151        **Args:**
152        - **datasets** (`list[str]`): the list of dataset class paths for each task. Each element in the list must be a string referring to a valid PyTorch Dataset class. It needs to be one in `self.AVAILABLE_DATASETS`.
153        - **root** (`list[str]`): the list of root directory where the original data files for constructing the CL dataset physically live.
154        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset).
155        - **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset).
156        - **batch_size** (`int` | `list[int]`): The batch size in train, val, test dataloader. If `list[str]`, it should be a list of integers, each integer is the batch size for each task.
157        - **num_workers** (`int` | `list[int]`): the number of workers for dataloaders. If `list[str]`, it should be a list of integers, each integer is the num of workers for each task.
158        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or list of them): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize, permute and so on are not included. If it is a list, each item is the custom transforms for each task.
159        - **repeat_channels** (`int` | `None` | list of them): the number of channels to repeat for each task. Default is None, which means no repeat. If not None, it should be an integer. If it is a list, each item is the number of channels to repeat for each task.
160        - **to_tensor** (`bool` | `list[bool]`): whether to include `ToTensor()` transform. Default is True.
161        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. If it is a list, each item is the size to resize for each task.
162        """
163        super().__init__(
164            datasets=datasets,
165            root=root,
166            batch_size=batch_size,
167            num_workers=num_workers,
168            custom_transforms=custom_transforms,
169            repeat_channels=repeat_channels,
170            to_tensor=to_tensor,
171            resize=resize,
172        )
173
174        self.test_percentage: float = test_percentage
175        """The percentage to randomly split some data into test data."""
176        self.validation_percentage: float = validation_percentage
177        """The percentage to randomly split some training data into validation data."""
178
179    def prepare_data(self) -> None:
180        r"""Download the original datasets if haven't."""
181
182        if self.task_id != 1:
183            return  # download all original datasets only at the beginning of first task
184
185        failed_dataset_classes = []
186        for task_id in range(1, self.num_tasks + 1):
187            root = self.root[task_id]
188            dataset_class = self.original_dataset_python_classes[task_id]
189            # torchvision datasets might have different APIs
190            try:
191                # collect the error and raise it at the end to avoid stopping the whole download process
192
193                if dataset_class in [
194                    ArabicHandwrittenDigits,
195                    KannadaMNIST,
196                    SignLanguageMNIST,
197                ]:
198                    # these datasets have no automatic download function, we require users to download them manually
199                    # the following code is just to check if the dataset is already downloaded
200                    dataset_class(root=root, train=True, download=False)
201                    dataset_class(root=root, train=False, download=False)
202
203                elif dataset_class in [
204                    Caltech101,
205                    Caltech256,
206                    EuroSAT,
207                    SEMEION,
208                    SUN397,
209                ]:
210                    # dataset classes that don't have any train, val, test split
211                    dataset_class(root=root, download=True)
212
213                elif dataset_class in [
214                    ArabicHandwrittenDigits,
215                    CIFAR10,
216                    CIFAR100,
217                    CUB2002011,
218                    EMNISTByClass,
219                    EMNISTByMerge,
220                    EMNISTBalanced,
221                    EMNISTLetters,
222                    EMNISTDigits,
223                    FaceScrub10,
224                    FaceScrub20,
225                    FaceScrub50,
226                    FaceScrub100,
227                    FaceScrubFromHAT,
228                    FashionMNIST,
229                    KannadaMNIST,
230                    KMNIST,
231                    Linnaeus5_32,
232                    Linnaeus5_64,
233                    Linnaeus5_128,
234                    Linnaeus5_256,
235                    MNIST,
236                    NotMNIST,
237                    NotMNISTFromHAT,
238                    SignLanguageMNIST,
239                    TrafficSignsFromHAT,
240                    USPS,
241                ]:
242                    # dataset classes that have `train` bool argument
243                    dataset_class(root=root, train=True, download=True)
244                    dataset_class(root=root, train=False, download=True)
245
246                elif dataset_class in [
247                    Food101,
248                    GTSRB,
249                    StanfordCars,
250                    SVHN,
251                ]:
252                    # dataset classes that have `split` argument with 'train', 'test'
253                    dataset_class(
254                        root=root,
255                        split="train",
256                        download=True,
257                    )
258                    dataset_class(
259                        root=root,
260                        split="test",
261                        download=True,
262                    )
263                elif dataset_class in [Country211]:
264                    # dataset classes that have `split` argument with 'train', 'valid', 'test'
265                    dataset_class(
266                        root=root,
267                        split="train",
268                        download=True,
269                    )
270                    dataset_class(
271                        root=root,
272                        split="valid",
273                        download=True,
274                    )
275                    dataset_class(
276                        root=root,
277                        split="test",
278                        download=True,
279                    )
280
281                elif dataset_class in [
282                    DTD,
283                    FGVCAircraftVariant,
284                    FGVCAircraftFamily,
285                    FGVCAircraftManufacturer,
286                    Flowers102,
287                    PCAM,
288                    RenderedSST2,
289                ]:
290                    # dataset classes that have `split` argument with 'train', 'val', 'test'
291                    dataset_class(
292                        root=root,
293                        split="train",
294                        download=True,
295                    )
296                    dataset_class(
297                        root=root,
298                        split="val",
299                        download=True,
300                    )
301                    dataset_class(
302                        root=root,
303                        split="test",
304                        download=True,
305                    )
306                elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]:
307                    # dataset classes that have `split` argument with 'trainval', 'test'
308                    dataset_class(
309                        root=root,
310                        split="trainval",
311                        download=True,
312                    )
313                    dataset_class(
314                        root=root,
315                        split="test",
316                        download=True,
317                    )
318                elif dataset_class == CelebA:
319                    # special case
320                    dataset_class(
321                        root=root,
322                        split="train",
323                        target_type="identity",
324                        download=True,
325                    )
326                    dataset_class(
327                        root=root,
328                        split="valid",
329                        target_type="identity",
330                        download=True,
331                    )
332                    dataset_class(
333                        root=root,
334                        split="test",
335                        target_type="identity",
336                        download=True,
337                    )
338                elif dataset_class == FER2013:
339                    # special case
340                    dataset_class(
341                        root=root,
342                        split="train",
343                    )
344                    dataset_class(
345                        root=root,
346                        split="test",
347                    )
348                elif dataset_class == TinyImageNet:
349                    # special case
350                    dataset_class(root=root)
351
352            except RuntimeError:
353                failed_dataset_classes.append(dataset_class)  # save for later prompt
354            else:
355                pylogger.debug(
356                    "The original %s dataset for task %s has been downloaded to %s.",
357                    dataset_class,
358                    task_id,
359                    root,
360                )
361
362        if failed_dataset_classes:
363            raise RuntimeError(
364                f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again or manually."
365            )
366
367    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
368        """Get the training and validation dataset of task `self.task_id`.
369
370        **Returns:**
371        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.
372        """
373        # torchvision datasets might have different APIs
374        if self.original_dataset_python_class_t in [
375            ArabicHandwrittenDigits,
376            CIFAR10,
377            CIFAR100,
378            CUB2002011,
379            EMNISTByClass,
380            EMNISTByMerge,
381            EMNISTBalanced,
382            EMNISTLetters,
383            EMNISTDigits,
384            FaceScrub10,
385            FaceScrub20,
386            FaceScrub50,
387            FaceScrub100,
388            FaceScrubFromHAT,
389            FashionMNIST,
390            KannadaMNIST,
391            KMNIST,
392            Linnaeus5_32,
393            Linnaeus5_64,
394            Linnaeus5_128,
395            Linnaeus5_256,
396            MNIST,
397            NotMNIST,
398            NotMNISTFromHAT,
399            SignLanguageMNIST,
400            TrafficSignsFromHAT,
401            USPS,
402        ]:
403            # dataset classes that have `train` bool argument
404            dataset_train_and_val = self.original_dataset_python_class_t(
405                root=self.root_t,
406                train=True,
407                transform=self.train_and_val_transforms(),
408                target_transform=self.target_transform(),
409            )
410
411            return random_split(
412                dataset_train_and_val,
413                lengths=[1 - self.validation_percentage, self.validation_percentage],
414                generator=torch.Generator().manual_seed(
415                    42
416                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
417            )
418        elif self.original_dataset_python_class_t in [
419            Caltech101,
420            Caltech256,
421            EuroSAT,
422            SEMEION,
423            SUN397,
424        ]:
425            # dataset classes that don't have train and test splt
426            dataset_all = self.original_dataset_python_class_t(
427                root=self.root_t,
428                transform=self.train_and_val_transforms(),
429                target_transform=self.target_transform(),
430            )
431
432            dataset_train_and_val, _ = random_split(
433                dataset_all,
434                lengths=[
435                    1 - self.test_percentage,
436                    self.test_percentage,
437                ],
438                generator=torch.Generator().manual_seed(
439                    42
440                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
441            )
442
443            return random_split(
444                dataset_train_and_val,
445                lengths=[1 - self.validation_percentage, self.validation_percentage],
446                generator=torch.Generator().manual_seed(
447                    42
448                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
449            )
450        elif self.original_dataset_python_class_t in [Country211]:
451            # dataset classes that have `split` argument with 'train', 'valid', 'test'
452            dataset_train = self.original_dataset_python_class_t(
453                root=self.root_t,
454                split="train",
455                transform=self.train_and_val_transforms(),
456                target_transform=self.target_transform(),
457            )
458            dataset_val = self.original_dataset_python_class_t(
459                root=self.root_t,
460                split="valid",
461                transform=self.train_and_val_transforms(),
462                target_transform=self.target_transform(),
463            )
464
465            return dataset_train, dataset_val
466
467        elif self.original_dataset_python_class_t in [
468            DTD,
469            FGVCAircraftVariant,
470            FGVCAircraftFamily,
471            FGVCAircraftManufacturer,
472            Flowers102,
473            PCAM,
474            RenderedSST2,
475        ]:
476            # dataset classes that have `split` argument with 'train', 'val', 'test'
477            dataset_train = self.original_dataset_python_class_t(
478                root=self.root_t,
479                split="train",
480                transform=self.train_and_val_transforms(),
481                target_transform=self.target_transform(),
482            )
483
484            dataset_val = self.original_dataset_python_class_t(
485                root=self.root_t,
486                split="val",
487                transform=self.train_and_val_transforms(),
488                target_transform=self.target_transform(),
489            )
490
491            return dataset_train, dataset_val
492
493        elif self.original_dataset_python_class_t in [
494            FER2013,
495            Food101,
496            GTSRB,
497            StanfordCars,
498            SVHN,
499            TinyImageNet,
500        ]:
501            # dataset classes that have `split` argument with 'train', 'test'
502
503            dataset_train_and_val = self.original_dataset_python_class_t(
504                root=self.root_t,
505                split="train",
506                transform=self.train_and_val_transforms(),
507                target_transform=self.target_transform(),
508            )
509
510            return random_split(
511                dataset_train_and_val,
512                lengths=[1 - self.validation_percentage, self.validation_percentage],
513                generator=torch.Generator().manual_seed(
514                    42
515                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
516            )
517
518        elif self.original_dataset_python_class_t in [OxfordIIITPet2, OxfordIIITPet37]:
519            # dataset classes that have `split` argument with 'trainval', 'test'
520
521            dataset_train_and_val = self.original_dataset_python_class_t(
522                root=self.root_t,
523                split="trainval",
524                transform=self.train_and_val_transforms(),
525                target_transform=self.target_transform(),
526            )
527
528            return random_split(
529                dataset_train_and_val,
530                lengths=[1 - self.validation_percentage, self.validation_percentage],
531                generator=torch.Generator().manual_seed(
532                    42
533                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
534            )
535
536        elif self.original_dataset_python_class_t in [CelebA]:
537            # special case
538            dataset_train = self.original_dataset_python_class_t(
539                root=self.root_t,
540                split="train",
541                target_type="identity",
542                transform=self.train_and_val_transforms(),
543                target_transform=self.target_transform(),
544            )
545
546            dataset_val = self.original_dataset_python_class_t(
547                root=self.root_t,
548                split="valid",
549                target_type="identity",
550                transform=self.train_and_val_transforms(),
551                target_transform=self.target_transform(),
552            )
553
554            return dataset_train, dataset_val
555
556    def test_dataset(self) -> Dataset:
557        r"""Get the test dataset of task `self.task_id`.
558
559        **Returns:**
560        - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`.
561        """
562        # torchvision datasets might have different APIs
563
564        if self.original_dataset_python_class_t in [
565            ArabicHandwrittenDigits,
566            CIFAR10,
567            CIFAR100,
568            CUB2002011,
569            EMNISTByClass,
570            EMNISTByMerge,
571            EMNISTBalanced,
572            EMNISTLetters,
573            EMNISTDigits,
574            FaceScrub10,
575            FaceScrub20,
576            FaceScrub50,
577            FaceScrub100,
578            FaceScrubFromHAT,
579            FashionMNIST,
580            KannadaMNIST,
581            KMNIST,
582            Linnaeus5_32,
583            Linnaeus5_64,
584            Linnaeus5_128,
585            Linnaeus5_256,
586            MNIST,
587            NotMNIST,
588            NotMNISTFromHAT,
589            SignLanguageMNIST,
590            TrafficSignsFromHAT,
591            USPS,
592        ]:
593            # dataset classes that have `train` bool argument
594            dataset_test = self.original_dataset_python_class_t(
595                root=self.root_t,
596                train=False,
597                transform=self.test_transforms(),
598                target_transform=self.target_transform(),
599            )
600
601            return dataset_test
602
603        elif self.original_dataset_python_class_t in [
604            Country211,
605            DTD,
606            FER2013,
607            FGVCAircraftVariant,
608            FGVCAircraftFamily,
609            FGVCAircraftManufacturer,
610            Flowers102,
611            Food101,
612            GTSRB,
613            OxfordIIITPet2,
614            OxfordIIITPet37,
615            PCAM,
616            RenderedSST2,
617            StanfordCars,
618            SVHN,
619        ]:
620            # dataset classes that have `split` argument with 'test'
621
622            dataset_test = self.original_dataset_python_class_t(
623                root=self.root_t,
624                split="test",
625                transform=self.test_transforms(),
626                target_transform=self.target_transform(),
627            )
628
629            return dataset_test
630
631        elif self.original_dataset_python_class_t in [
632            Caltech101,
633            Caltech256,
634            EuroSAT,
635            SEMEION,
636            SUN397,
637        ]:
638            # dataset classes that don't have train and test splt
639
640            dataset_all = self.original_dataset_python_class_t(
641                root=self.root_t,
642                transform=self.train_and_val_transforms(),
643                target_transform=self.target_transform(),
644            )
645
646            _, dataset_test = random_split(
647                dataset_all,
648                lengths=[1 - self.test_percentage, self.test_percentage],
649                generator=torch.Generator().manual_seed(42),
650            )
651
652            return dataset_test
653
654        elif self.original_dataset_python_class_t in [CelebA]:
655            # special case
656            dataset_test = self.original_dataset_python_class_t(
657                root=self.root_t,
658                split="test",
659                target_type="identity",
660                transform=self.test_transforms(),
661                target_transform=self.target_transform(),
662            )
663
664            return dataset_test
665
666        elif self.original_dataset_python_class_t in [TinyImageNet]:
667            # special case
668            dataset_test = self.original_dataset_python_class_t(
669                root=self.root_t,
670                split="val",
671                transform=self.test_transforms(),
672                target_transform=self.target_transform(),
673            )
674
675            return dataset_test

Combined CL dataset from available datasets.

Combined( datasets: list[str], root: list[str], validation_percentage: float, test_percentage: float, batch_size: int | list[int] = 1, num_workers: int | list[int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, list[Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | list[int | None] = None, to_tensor: bool | list[bool] = True, resize: tuple[int, int] | None | list[tuple[int, int] | None] = None)
131    def __init__(
132        self,
133        datasets: list[str],
134        root: list[str],
135        validation_percentage: float,
136        test_percentage: float,
137        batch_size: int | list[int] = 1,
138        num_workers: int | list[int] = 0,
139        custom_transforms: (
140            Callable
141            | transforms.Compose
142            | None
143            | list[Callable | transforms.Compose | None]
144        ) = None,
145        repeat_channels: int | None | list[int | None] = None,
146        to_tensor: bool | list[bool] = True,
147        resize: tuple[int, int] | None | list[tuple[int, int] | None] = None,
148    ) -> None:
149        r"""Initialize the Combined Torchvision dataset object providing the root where data files live.
150
151        **Args:**
152        - **datasets** (`list[str]`): the list of dataset class paths for each task. Each element in the list must be a string referring to a valid PyTorch Dataset class. It needs to be one in `self.AVAILABLE_DATASETS`.
153        - **root** (`list[str]`): the list of root directory where the original data files for constructing the CL dataset physically live.
154        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset).
155        - **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset).
156        - **batch_size** (`int` | `list[int]`): The batch size in train, val, test dataloader. If `list[str]`, it should be a list of integers, each integer is the batch size for each task.
157        - **num_workers** (`int` | `list[int]`): the number of workers for dataloaders. If `list[str]`, it should be a list of integers, each integer is the num of workers for each task.
158        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or list of them): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize, permute and so on are not included. If it is a list, each item is the custom transforms for each task.
159        - **repeat_channels** (`int` | `None` | list of them): the number of channels to repeat for each task. Default is None, which means no repeat. If not None, it should be an integer. If it is a list, each item is the number of channels to repeat for each task.
160        - **to_tensor** (`bool` | `list[bool]`): whether to include `ToTensor()` transform. Default is True.
161        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. If it is a list, each item is the size to resize for each task.
162        """
163        super().__init__(
164            datasets=datasets,
165            root=root,
166            batch_size=batch_size,
167            num_workers=num_workers,
168            custom_transforms=custom_transforms,
169            repeat_channels=repeat_channels,
170            to_tensor=to_tensor,
171            resize=resize,
172        )
173
174        self.test_percentage: float = test_percentage
175        """The percentage to randomly split some data into test data."""
176        self.validation_percentage: float = validation_percentage
177        """The percentage to randomly split some training data into validation data."""

Initialize the Combined Torchvision dataset object providing the root where data files live.

Args:

  • datasets (list[str]): the list of dataset class paths for each task. Each element in the list must be a string referring to a valid PyTorch Dataset class. It needs to be one in self.AVAILABLE_DATASETS.
  • root (list[str]): the list of root directory where the original data files for constructing the CL dataset physically live.
  • validation_percentage (float): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset).
  • test_percentage (float): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset).
  • batch_size (int | list[int]): the batch size in train, val, test dataloaders. If list[int], each integer is the batch size for the corresponding task.
  • num_workers (int | list[int]): the number of workers for dataloaders. If list[int], each integer is the number of workers for the corresponding task.
  • custom_transforms (transform or transforms.Compose or None or list of them): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalize, permute and so on are not included. If it is a list, each item is the custom transforms for each task.
  • repeat_channels (int | None | list of them): the number of channels to repeat for each task. Default is None, which means no repeat. If not None, it should be an integer. If it is a list, each item is the number of channels to repeat for each task.
  • to_tensor (bool | list[bool]): whether to include ToTensor() transform. Default is True.
  • resize (tuple[int, int] | None or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. If it is a list, each item is the size to resize for each task.
AVAILABLE_DATASETS: list[torchvision.datasets.vision.VisionDataset] = [<class 'clarena.stl_datasets.raw.ahdd.ArabicHandwrittenDigits'>, <class 'torchvision.datasets.cifar.CIFAR10'>, <class 'torchvision.datasets.cifar.CIFAR100'>, <class 'clarena.stl_datasets.raw.cub2002011.CUB2002011'>, <class 'torchvision.datasets.caltech.Caltech101'>, <class 'torchvision.datasets.caltech.Caltech256'>, <class 'torchvision.datasets.celeba.CelebA'>, <class 'torchvision.datasets.country211.Country211'>, <class 'torchvision.datasets.dtd.DTD'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTBalanced'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTByClass'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTByMerge'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTDigits'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTLetters'>, <class 'torchvision.datasets.eurosat.EuroSAT'>, <class 'torchvision.datasets.fer2013.FER2013'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftFamily'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftManufacturer'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftVariant'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub10'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub100'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrubFromHAT'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub20'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub50'>, <class 'torchvision.datasets.mnist.FashionMNIST'>, <class 'torchvision.datasets.flowers102.Flowers102'>, <class 'torchvision.datasets.food101.Food101'>, <class 'torchvision.datasets.gtsrb.GTSRB'>, <class 'torchvision.datasets.mnist.KMNIST'>, <class 'clarena.stl_datasets.raw.kannada_mnist.KannadaMNIST'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_128'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_256'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_32'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_64'>, <class 
'torchvision.datasets.mnist.MNIST'>, <class 'clarena.stl_datasets.raw.notmnist.NotMNIST'>, <class 'clarena.stl_datasets.raw.notmnist.NotMNISTFromHAT'>, <class 'clarena.stl_datasets.raw.oxford_iiit_pet.OxfordIIITPet2'>, <class 'clarena.stl_datasets.raw.oxford_iiit_pet.OxfordIIITPet37'>, <class 'torchvision.datasets.pcam.PCAM'>, <class 'torchvision.datasets.rendered_sst2.RenderedSST2'>, <class 'torchvision.datasets.semeion.SEMEION'>, <class 'torchvision.datasets.sun397.SUN397'>, <class 'torchvision.datasets.svhn.SVHN'>, <class 'clarena.stl_datasets.raw.sign_language_mnist.SignLanguageMNIST'>, <class 'torchvision.datasets.stanford_cars.StanfordCars'>, <class 'clarena.stl_datasets.raw.traffic_signs.TrafficSignsFromHAT'>, <class 'tinyimagenet.TinyImageNet'>, <class 'torchvision.datasets.usps.USPS'>]

The list of available datasets.

test_percentage: float

The percentage to randomly split some data into test data.

validation_percentage: float

The percentage to randomly split some training data into validation data.

def prepare_data(self) -> None:
    r"""Download the original datasets if they haven't been downloaded yet.

    Runs only at the beginning of the first task and downloads the original
    datasets for **all** tasks in one pass. Download errors are collected and
    raised together at the end, so one failing dataset doesn't stop the rest.

    **Raises:**
    - **RuntimeError**: if any dataset failed to download (raised after all
      downloads have been attempted).
    """
    if self.task_id != 1:
        return  # download all original datasets only at the beginning of first task

    failed_dataset_classes = []
    for task_id in range(1, self.num_tasks + 1):
        root = self.root[task_id]
        dataset_class = self.original_dataset_python_classes[task_id]
        # torchvision datasets might have different APIs
        try:
            # collect the error and raise it at the end to avoid stopping the whole download process

            if dataset_class in [
                ArabicHandwrittenDigits,
                KannadaMNIST,
                SignLanguageMNIST,
            ]:
                # these datasets have no automatic download function, we require users to download them manually
                # the following code is just to check if the dataset is already downloaded
                dataset_class(root=root, train=True, download=False)
                dataset_class(root=root, train=False, download=False)

            elif dataset_class in [
                Caltech101,
                Caltech256,
                EuroSAT,
                SEMEION,
                SUN397,
            ]:
                # dataset classes that don't have any train, val, test split
                dataset_class(root=root, download=True)

            elif dataset_class in [
                # NOTE: ArabicHandwrittenDigits, KannadaMNIST and SignLanguageMNIST
                # also take a `train` bool argument, but they are handled by the
                # manual-download branch above; re-listing them here was unreachable
                # dead code and has been removed.
                CIFAR10,
                CIFAR100,
                CUB2002011,
                EMNISTByClass,
                EMNISTByMerge,
                EMNISTBalanced,
                EMNISTLetters,
                EMNISTDigits,
                FaceScrub10,
                FaceScrub20,
                FaceScrub50,
                FaceScrub100,
                FaceScrubFromHAT,
                FashionMNIST,
                KMNIST,
                Linnaeus5_32,
                Linnaeus5_64,
                Linnaeus5_128,
                Linnaeus5_256,
                MNIST,
                NotMNIST,
                NotMNISTFromHAT,
                TrafficSignsFromHAT,
                USPS,
            ]:
                # dataset classes that have `train` bool argument
                dataset_class(root=root, train=True, download=True)
                dataset_class(root=root, train=False, download=True)

            elif dataset_class in [
                Food101,
                GTSRB,
                StanfordCars,
                SVHN,
            ]:
                # dataset classes that have `split` argument with 'train', 'test'
                dataset_class(
                    root=root,
                    split="train",
                    download=True,
                )
                dataset_class(
                    root=root,
                    split="test",
                    download=True,
                )
            elif dataset_class in [Country211]:
                # dataset classes that have `split` argument with 'train', 'valid', 'test'
                dataset_class(
                    root=root,
                    split="train",
                    download=True,
                )
                dataset_class(
                    root=root,
                    split="valid",
                    download=True,
                )
                dataset_class(
                    root=root,
                    split="test",
                    download=True,
                )

            elif dataset_class in [
                DTD,
                FGVCAircraftVariant,
                FGVCAircraftFamily,
                FGVCAircraftManufacturer,
                Flowers102,
                PCAM,
                RenderedSST2,
            ]:
                # dataset classes that have `split` argument with 'train', 'val', 'test'
                dataset_class(
                    root=root,
                    split="train",
                    download=True,
                )
                dataset_class(
                    root=root,
                    split="val",
                    download=True,
                )
                dataset_class(
                    root=root,
                    split="test",
                    download=True,
                )
            elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]:
                # dataset classes that have `split` argument with 'trainval', 'test'
                dataset_class(
                    root=root,
                    split="trainval",
                    download=True,
                )
                dataset_class(
                    root=root,
                    split="test",
                    download=True,
                )
            elif dataset_class == CelebA:
                # special case: needs `target_type` in addition to `split`
                dataset_class(
                    root=root,
                    split="train",
                    target_type="identity",
                    download=True,
                )
                dataset_class(
                    root=root,
                    split="valid",
                    target_type="identity",
                    download=True,
                )
                dataset_class(
                    root=root,
                    split="test",
                    target_type="identity",
                    download=True,
                )
            elif dataset_class == FER2013:
                # special case: FER2013 has no `download` argument; constructing it
                # just checks that the manually-downloaded data is in place
                dataset_class(
                    root=root,
                    split="train",
                )
                dataset_class(
                    root=root,
                    split="test",
                )
            elif dataset_class == TinyImageNet:
                # special case: downloads on construction, no split/download arguments here
                dataset_class(root=root)

        except RuntimeError:
            failed_dataset_classes.append(dataset_class)  # save for later prompt
        else:
            pylogger.debug(
                "The original %s dataset for task %s has been downloaded to %s.",
                dataset_class,
                task_id,
                root,
            )

    if failed_dataset_classes:
        raise RuntimeError(
            f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again or manually."
        )

Download the original datasets if they haven't been downloaded yet.

def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
    """Get the training and validation dataset of task `self.task_id`.

    **Returns:**
    - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.
    """
    cls_t = self.original_dataset_python_class_t

    def carve_out_validation(whole: Dataset) -> list[Dataset]:
        # the seed must stay fixed to make sure the datasets across experiments are
        # the same; don't hand it to the global seed, which might vary across experiments
        return random_split(
            whole,
            lengths=[1 - self.validation_percentage, self.validation_percentage],
            generator=torch.Generator().manual_seed(42),
        )

    def build(split_name: str) -> Dataset:
        # builder for dataset classes that take a `split` string argument
        return cls_t(
            root=self.root_t,
            split=split_name,
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
        )

    # torchvision datasets might have different APIs, so dispatch on the class
    if cls_t in (
        ArabicHandwrittenDigits,
        CIFAR10,
        CIFAR100,
        CUB2002011,
        EMNISTByClass,
        EMNISTByMerge,
        EMNISTBalanced,
        EMNISTLetters,
        EMNISTDigits,
        FaceScrub10,
        FaceScrub20,
        FaceScrub50,
        FaceScrub100,
        FaceScrubFromHAT,
        FashionMNIST,
        KannadaMNIST,
        KMNIST,
        Linnaeus5_32,
        Linnaeus5_64,
        Linnaeus5_128,
        Linnaeus5_256,
        MNIST,
        NotMNIST,
        NotMNISTFromHAT,
        SignLanguageMNIST,
        TrafficSignsFromHAT,
        USPS,
    ):
        # dataset classes that have a `train` bool argument
        whole_train = cls_t(
            root=self.root_t,
            train=True,
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
        )
        return carve_out_validation(whole_train)

    if cls_t in (Caltech101, Caltech256, EuroSAT, SEMEION, SUN397):
        # dataset classes without any predefined train/test split: first carve out
        # the test portion (same fixed seed as `test_dataset`, so the two splits agree),
        # then carve the validation portion from what remains
        everything = cls_t(
            root=self.root_t,
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
        )
        remaining, _ = random_split(
            everything,
            lengths=[
                1 - self.test_percentage,
                self.test_percentage,
            ],
            generator=torch.Generator().manual_seed(42),
        )
        return carve_out_validation(remaining)

    if cls_t is Country211:
        # has `split` argument with 'train', 'valid', 'test'
        return build("train"), build("valid")

    if cls_t in (
        DTD,
        FGVCAircraftVariant,
        FGVCAircraftFamily,
        FGVCAircraftManufacturer,
        Flowers102,
        PCAM,
        RenderedSST2,
    ):
        # have `split` argument with 'train', 'val', 'test'
        return build("train"), build("val")

    if cls_t in (FER2013, Food101, GTSRB, StanfordCars, SVHN, TinyImageNet):
        # have `split` argument with 'train', 'test' only: carve validation out of 'train'
        return carve_out_validation(build("train"))

    if cls_t in (OxfordIIITPet2, OxfordIIITPet37):
        # have `split` argument with 'trainval', 'test'
        return carve_out_validation(build("trainval"))

    if cls_t is CelebA:
        # special case: needs `target_type="identity"` in addition to `split`
        def build_celeba(split_name: str) -> Dataset:
            return cls_t(
                root=self.root_t,
                split=split_name,
                target_type="identity",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

        return build_celeba("train"), build_celeba("valid")

Get the training and validation dataset of task self.task_id.

Returns:

  • train_and_val_dataset (tuple[Dataset, Dataset]): the train and validation dataset of task self.task_id.
def test_dataset(self) -> Dataset:
    r"""Get the test dataset of task `self.task_id`.

    **Returns:**
    - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`.
    """
    # torchvision datasets might have different APIs

    if self.original_dataset_python_class_t in [
        ArabicHandwrittenDigits,
        CIFAR10,
        CIFAR100,
        CUB2002011,
        EMNISTByClass,
        EMNISTByMerge,
        EMNISTBalanced,
        EMNISTLetters,
        EMNISTDigits,
        FaceScrub10,
        FaceScrub20,
        FaceScrub50,
        FaceScrub100,
        FaceScrubFromHAT,
        FashionMNIST,
        KannadaMNIST,
        KMNIST,
        Linnaeus5_32,
        Linnaeus5_64,
        Linnaeus5_128,
        Linnaeus5_256,
        MNIST,
        NotMNIST,
        NotMNISTFromHAT,
        SignLanguageMNIST,
        TrafficSignsFromHAT,
        USPS,
    ]:
        # dataset classes that have `train` bool argument
        dataset_test = self.original_dataset_python_class_t(
            root=self.root_t,
            train=False,
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
        )

        return dataset_test

    elif self.original_dataset_python_class_t in [
        Country211,
        DTD,
        FER2013,
        FGVCAircraftVariant,
        FGVCAircraftFamily,
        FGVCAircraftManufacturer,
        Flowers102,
        Food101,
        GTSRB,
        OxfordIIITPet2,
        OxfordIIITPet37,
        PCAM,
        RenderedSST2,
        StanfordCars,
        SVHN,
    ]:
        # dataset classes that have `split` argument with 'test'

        dataset_test = self.original_dataset_python_class_t(
            root=self.root_t,
            split="test",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
        )

        return dataset_test

    elif self.original_dataset_python_class_t in [
        Caltech101,
        Caltech256,
        EuroSAT,
        SEMEION,
        SUN397,
    ]:
        # dataset classes that don't have train and test split

        # FIX: this previously used `self.train_and_val_transforms()`, which applied
        # training transforms (e.g. augmentation) to test data. The transform does not
        # affect `random_split` itself, so the split below stays aligned with the one
        # in `train_and_val_dataset` (same fixed seed 42).
        dataset_all = self.original_dataset_python_class_t(
            root=self.root_t,
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
        )

        _, dataset_test = random_split(
            dataset_all,
            lengths=[1 - self.test_percentage, self.test_percentage],
            generator=torch.Generator().manual_seed(42),
        )

        return dataset_test

    elif self.original_dataset_python_class_t in [CelebA]:
        # special case: needs `target_type` in addition to `split`
        dataset_test = self.original_dataset_python_class_t(
            root=self.root_t,
            split="test",
            target_type="identity",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
        )

        return dataset_test

    elif self.original_dataset_python_class_t in [TinyImageNet]:
        # special case: TinyImageNet has no public test labels, so its 'val' split
        # is used as the test set
        dataset_test = self.original_dataset_python_class_t(
            root=self.root_t,
            split="val",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
        )

        return dataset_test

Get the test dataset of task self.task_id.

Returns:

  • test_dataset (Dataset): the test dataset of task self.task_id.