clarena.cl_datasets.combined
The submodule in cl_datasets for combined datasets.
1r""" 2The submodule in `cl_datasets` for combined datasets. 3""" 4 5__all__ = ["Combined"] 6 7import logging 8from typing import Callable 9 10import torch 11from tinyimagenet import TinyImageNet 12from torch.utils.data import Dataset, random_split 13from torchvision import transforms 14from torchvision.datasets import ( 15 CIFAR10, 16 CIFAR100, 17 DTD, 18 FER2013, 19 GTSRB, 20 KMNIST, 21 MNIST, 22 PCAM, 23 SEMEION, 24 SUN397, 25 SVHN, 26 USPS, 27 Caltech101, 28 Caltech256, 29 CelebA, 30 Country211, 31 EuroSAT, 32 FashionMNIST, 33 Flowers102, 34 Food101, 35 RenderedSST2, 36 StanfordCars, 37) 38from torchvision.datasets.vision import VisionDataset 39 40from clarena.cl_datasets import CLCombinedDataset 41from clarena.stl_datasets.raw import ( 42 CUB2002011, 43 ArabicHandwrittenDigits, 44 EMNISTBalanced, 45 EMNISTByClass, 46 EMNISTByMerge, 47 EMNISTDigits, 48 EMNISTLetters, 49 FaceScrub10, 50 FaceScrub20, 51 FaceScrub50, 52 FaceScrub100, 53 FaceScrubFromHAT, 54 FGVCAircraftFamily, 55 FGVCAircraftManufacturer, 56 FGVCAircraftVariant, 57 KannadaMNIST, 58 Linnaeus5_32, 59 Linnaeus5_64, 60 Linnaeus5_128, 61 Linnaeus5_256, 62 NotMNIST, 63 NotMNISTFromHAT, 64 OxfordIIITPet2, 65 OxfordIIITPet37, 66 SignLanguageMNIST, 67 TrafficSignsFromHAT, 68) 69 70# always get logger for built-in logging in each module 71pylogger = logging.getLogger(__name__) 72 73 74class Combined(CLCombinedDataset): 75 r"""Combined CL dataset from available datasets.""" 76 77 AVAILABLE_DATASETS: list[VisionDataset] = [ 78 ArabicHandwrittenDigits, 79 CIFAR10, 80 CIFAR100, 81 CUB2002011, 82 Caltech101, 83 Caltech256, 84 CelebA, 85 Country211, 86 DTD, 87 EMNISTBalanced, 88 EMNISTByClass, 89 EMNISTByMerge, 90 EMNISTDigits, 91 EMNISTLetters, 92 EuroSAT, 93 FER2013, 94 FGVCAircraftFamily, 95 FGVCAircraftManufacturer, 96 FGVCAircraftVariant, 97 FaceScrub10, 98 FaceScrub100, 99 FaceScrubFromHAT, 100 FaceScrub20, 101 FaceScrub50, 102 FashionMNIST, 103 Flowers102, 104 Food101, 105 GTSRB, 106 KMNIST, 107 
KannadaMNIST, 108 Linnaeus5_128, 109 Linnaeus5_256, 110 Linnaeus5_32, 111 Linnaeus5_64, 112 MNIST, 113 NotMNIST, 114 NotMNISTFromHAT, 115 OxfordIIITPet2, 116 OxfordIIITPet37, 117 PCAM, 118 RenderedSST2, 119 SEMEION, 120 SUN397, 121 SVHN, 122 SignLanguageMNIST, 123 StanfordCars, 124 TrafficSignsFromHAT, 125 TinyImageNet, 126 USPS, 127 ] 128 r"""The list of available datasets.""" 129 130 def __init__( 131 self, 132 datasets: list[str], 133 root: list[str], 134 validation_percentage: float, 135 test_percentage: float, 136 batch_size: int | list[int] = 1, 137 num_workers: int | list[int] = 0, 138 custom_transforms: ( 139 Callable 140 | transforms.Compose 141 | None 142 | list[Callable | transforms.Compose | None] 143 ) = None, 144 repeat_channels: int | None | list[int | None] = None, 145 to_tensor: bool | list[bool] = True, 146 resize: tuple[int, int] | None | list[tuple[int, int] | None] = None, 147 ) -> None: 148 r"""Initialize the Combined Torchvision dataset object providing the root where data files live. 149 150 **Args:** 151 - **datasets** (`list[str]`): the list of dataset class paths for each task. Each element in the list must be a string referring to a valid PyTorch Dataset class. It needs to be one in `self.AVAILABLE_DATASETS`. 152 - **root** (`list[str]`): the list of root directory where the original data files for constructing the CL dataset physically live. 153 - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset). 154 - **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset). 155 - **batch_size** (`int` | `list[int]`): The batch size in train, val, test dataloader. If `list[str]`, it should be a list of integers, each integer is the batch size for each task. 
156 - **num_workers** (`int` | `list[int]`): the number of workers for dataloaders. If `list[str]`, it should be a list of integers, each integer is the num of workers for each task. 157 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or list of them): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize, permute and so on are not included. If it is a list, each item is the custom transforms for each task. 158 - **repeat_channels** (`int` | `None` | list of them): the number of channels to repeat for each task. Default is None, which means no repeat. If not None, it should be an integer. If it is a list, each item is the number of channels to repeat for each task. 159 - **to_tensor** (`bool` | `list[bool]`): whether to include `ToTensor()` transform. Default is True. 160 - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. If it is a list, each item is the size to resize for each task. 
161 """ 162 super().__init__( 163 datasets=datasets, 164 root=root, 165 batch_size=batch_size, 166 num_workers=num_workers, 167 custom_transforms=custom_transforms, 168 repeat_channels=repeat_channels, 169 to_tensor=to_tensor, 170 resize=resize, 171 ) 172 173 self.test_percentage: float = test_percentage 174 """The percentage to randomly split some data into test data.""" 175 self.validation_percentage: float = validation_percentage 176 """The percentage to randomly split some training data into validation data.""" 177 178 def prepare_data(self) -> None: 179 r"""Download the original datasets if haven't.""" 180 181 if self.task_id != 1: 182 return # download all original datasets only at the beginning of first task 183 184 failed_dataset_classes = [] 185 for task_id in range(1, self.num_tasks + 1): 186 root = self.root[task_id] 187 dataset_class = self.original_dataset_python_classes[task_id] 188 # torchvision datasets might have different APIs 189 try: 190 # collect the error and raise it at the end to avoid stopping the whole download process 191 192 if dataset_class in [ 193 ArabicHandwrittenDigits, 194 KannadaMNIST, 195 SignLanguageMNIST, 196 ]: 197 # these datasets have no automatic download function, we require users to download them manually 198 # the following code is just to check if the dataset is already downloaded 199 dataset_class(root=root, train=True, download=False) 200 dataset_class(root=root, train=False, download=False) 201 202 elif dataset_class in [ 203 Caltech101, 204 Caltech256, 205 EuroSAT, 206 SEMEION, 207 SUN397, 208 ]: 209 # dataset classes that don't have any train, val, test split 210 dataset_class(root=root, download=True) 211 212 elif dataset_class in [ 213 ArabicHandwrittenDigits, 214 CIFAR10, 215 CIFAR100, 216 CUB2002011, 217 EMNISTByClass, 218 EMNISTByMerge, 219 EMNISTBalanced, 220 EMNISTLetters, 221 EMNISTDigits, 222 FaceScrub10, 223 FaceScrub20, 224 FaceScrub50, 225 FaceScrub100, 226 FaceScrubFromHAT, 227 FashionMNIST, 228 
KannadaMNIST, 229 KMNIST, 230 Linnaeus5_32, 231 Linnaeus5_64, 232 Linnaeus5_128, 233 Linnaeus5_256, 234 MNIST, 235 NotMNIST, 236 NotMNISTFromHAT, 237 SignLanguageMNIST, 238 TrafficSignsFromHAT, 239 USPS, 240 ]: 241 # dataset classes that have `train` bool argument 242 dataset_class(root=root, train=True, download=True) 243 dataset_class(root=root, train=False, download=True) 244 245 elif dataset_class in [ 246 Food101, 247 GTSRB, 248 StanfordCars, 249 SVHN, 250 ]: 251 # dataset classes that have `split` argument with 'train', 'test' 252 dataset_class( 253 root=root, 254 split="train", 255 download=True, 256 ) 257 dataset_class( 258 root=root, 259 split="test", 260 download=True, 261 ) 262 elif dataset_class in [Country211]: 263 # dataset classes that have `split` argument with 'train', 'valid', 'test' 264 dataset_class( 265 root=root, 266 split="train", 267 download=True, 268 ) 269 dataset_class( 270 root=root, 271 split="valid", 272 download=True, 273 ) 274 dataset_class( 275 root=root, 276 split="test", 277 download=True, 278 ) 279 280 elif dataset_class in [ 281 DTD, 282 FGVCAircraftVariant, 283 FGVCAircraftFamily, 284 FGVCAircraftManufacturer, 285 Flowers102, 286 PCAM, 287 RenderedSST2, 288 ]: 289 # dataset classes that have `split` argument with 'train', 'val', 'test' 290 dataset_class( 291 root=root, 292 split="train", 293 download=True, 294 ) 295 dataset_class( 296 root=root, 297 split="val", 298 download=True, 299 ) 300 dataset_class( 301 root=root, 302 split="test", 303 download=True, 304 ) 305 elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]: 306 # dataset classes that have `split` argument with 'trainval', 'test' 307 dataset_class( 308 root=root, 309 split="trainval", 310 download=True, 311 ) 312 dataset_class( 313 root=root, 314 split="test", 315 download=True, 316 ) 317 elif dataset_class == CelebA: 318 # special case 319 dataset_class( 320 root=root, 321 split="train", 322 target_type="identity", 323 download=True, 324 ) 325 dataset_class( 326 
root=root, 327 split="valid", 328 target_type="identity", 329 download=True, 330 ) 331 dataset_class( 332 root=root, 333 split="test", 334 target_type="identity", 335 download=True, 336 ) 337 elif dataset_class == FER2013: 338 # special case 339 dataset_class( 340 root=root, 341 split="train", 342 ) 343 dataset_class( 344 root=root, 345 split="test", 346 ) 347 elif dataset_class == TinyImageNet: 348 # special case 349 dataset_class(root=root) 350 351 except RuntimeError: 352 failed_dataset_classes.append(dataset_class) # save for later prompt 353 else: 354 pylogger.debug( 355 "The original %s dataset for task %s has been downloaded to %s.", 356 dataset_class, 357 task_id, 358 root, 359 ) 360 361 if failed_dataset_classes: 362 raise RuntimeError( 363 f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again or manually." 364 ) 365 366 def train_and_val_dataset(self) -> tuple[Dataset, Dataset]: 367 """Get the training and validation dataset of task `self.task_id`. 368 369 **Returns:** 370 - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`. 
371 """ 372 # torchvision datasets might have different APIs 373 if self.original_dataset_python_class_t in [ 374 ArabicHandwrittenDigits, 375 CIFAR10, 376 CIFAR100, 377 CUB2002011, 378 EMNISTByClass, 379 EMNISTByMerge, 380 EMNISTBalanced, 381 EMNISTLetters, 382 EMNISTDigits, 383 FaceScrub10, 384 FaceScrub20, 385 FaceScrub50, 386 FaceScrub100, 387 FaceScrubFromHAT, 388 FashionMNIST, 389 KannadaMNIST, 390 KMNIST, 391 Linnaeus5_32, 392 Linnaeus5_64, 393 Linnaeus5_128, 394 Linnaeus5_256, 395 MNIST, 396 NotMNIST, 397 NotMNISTFromHAT, 398 SignLanguageMNIST, 399 TrafficSignsFromHAT, 400 USPS, 401 ]: 402 # dataset classes that have `train` bool argument 403 dataset_train_and_val = self.original_dataset_python_class_t( 404 root=self.root_t, 405 train=True, 406 transform=self.train_and_val_transforms(), 407 target_transform=self.target_transform(), 408 ) 409 410 return random_split( 411 dataset_train_and_val, 412 lengths=[1 - self.validation_percentage, self.validation_percentage], 413 generator=torch.Generator().manual_seed( 414 42 415 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 416 ) 417 elif self.original_dataset_python_class_t in [ 418 Caltech101, 419 Caltech256, 420 EuroSAT, 421 SEMEION, 422 SUN397, 423 ]: 424 # dataset classes that don't have train and test splt 425 dataset_all = self.original_dataset_python_class_t( 426 root=self.root_t, 427 transform=self.train_and_val_transforms(), 428 target_transform=self.target_transform(), 429 ) 430 431 dataset_train_and_val, _ = random_split( 432 dataset_all, 433 lengths=[ 434 1 - self.test_percentage, 435 self.test_percentage, 436 ], 437 generator=torch.Generator().manual_seed( 438 42 439 ), # this must be set fixed to make sure the datasets across experiments are the same. 
Don't handle it to global seed as it might vary across experiments 440 ) 441 442 return random_split( 443 dataset_train_and_val, 444 lengths=[1 - self.validation_percentage, self.validation_percentage], 445 generator=torch.Generator().manual_seed( 446 42 447 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 448 ) 449 elif self.original_dataset_python_class_t in [Country211]: 450 # dataset classes that have `split` argument with 'train', 'valid', 'test' 451 dataset_train = self.original_dataset_python_class_t( 452 root=self.root_t, 453 split="train", 454 transform=self.train_and_val_transforms(), 455 target_transform=self.target_transform(), 456 ) 457 dataset_val = self.original_dataset_python_class_t( 458 root=self.root_t, 459 split="valid", 460 transform=self.train_and_val_transforms(), 461 target_transform=self.target_transform(), 462 ) 463 464 return dataset_train, dataset_val 465 466 elif self.original_dataset_python_class_t in [ 467 DTD, 468 FGVCAircraftVariant, 469 FGVCAircraftFamily, 470 FGVCAircraftManufacturer, 471 Flowers102, 472 PCAM, 473 RenderedSST2, 474 ]: 475 # dataset classes that have `split` argument with 'train', 'val', 'test' 476 dataset_train = self.original_dataset_python_class_t( 477 root=self.root_t, 478 split="train", 479 transform=self.train_and_val_transforms(), 480 target_transform=self.target_transform(), 481 ) 482 483 dataset_val = self.original_dataset_python_class_t( 484 root=self.root_t, 485 split="val", 486 transform=self.train_and_val_transforms(), 487 target_transform=self.target_transform(), 488 ) 489 490 return dataset_train, dataset_val 491 492 elif self.original_dataset_python_class_t in [ 493 FER2013, 494 Food101, 495 GTSRB, 496 StanfordCars, 497 SVHN, 498 TinyImageNet, 499 ]: 500 # dataset classes that have `split` argument with 'train', 'test' 501 502 dataset_train_and_val = self.original_dataset_python_class_t( 503 
root=self.root_t, 504 split="train", 505 transform=self.train_and_val_transforms(), 506 target_transform=self.target_transform(), 507 ) 508 509 return random_split( 510 dataset_train_and_val, 511 lengths=[1 - self.validation_percentage, self.validation_percentage], 512 generator=torch.Generator().manual_seed( 513 42 514 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 515 ) 516 517 elif self.original_dataset_python_class_t in [OxfordIIITPet2, OxfordIIITPet37]: 518 # dataset classes that have `split` argument with 'trainval', 'test' 519 520 dataset_train_and_val = self.original_dataset_python_class_t( 521 root=self.root_t, 522 split="trainval", 523 transform=self.train_and_val_transforms(), 524 target_transform=self.target_transform(), 525 ) 526 527 return random_split( 528 dataset_train_and_val, 529 lengths=[1 - self.validation_percentage, self.validation_percentage], 530 generator=torch.Generator().manual_seed( 531 42 532 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 533 ) 534 535 elif self.original_dataset_python_class_t in [CelebA]: 536 # special case 537 dataset_train = self.original_dataset_python_class_t( 538 root=self.root_t, 539 split="train", 540 target_type="identity", 541 transform=self.train_and_val_transforms(), 542 target_transform=self.target_transform(), 543 ) 544 545 dataset_val = self.original_dataset_python_class_t( 546 root=self.root_t, 547 split="valid", 548 target_type="identity", 549 transform=self.train_and_val_transforms(), 550 target_transform=self.target_transform(), 551 ) 552 553 return dataset_train, dataset_val 554 555 def test_dataset(self) -> Dataset: 556 r"""Get the test dataset of task `self.task_id`. 557 558 **Returns:** 559 - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`. 
560 """ 561 # torchvision datasets might have different APIs 562 563 if self.original_dataset_python_class_t in [ 564 ArabicHandwrittenDigits, 565 CIFAR10, 566 CIFAR100, 567 CUB2002011, 568 EMNISTByClass, 569 EMNISTByMerge, 570 EMNISTBalanced, 571 EMNISTLetters, 572 EMNISTDigits, 573 FaceScrub10, 574 FaceScrub20, 575 FaceScrub50, 576 FaceScrub100, 577 FaceScrubFromHAT, 578 FashionMNIST, 579 KannadaMNIST, 580 KMNIST, 581 Linnaeus5_32, 582 Linnaeus5_64, 583 Linnaeus5_128, 584 Linnaeus5_256, 585 MNIST, 586 NotMNIST, 587 NotMNISTFromHAT, 588 SignLanguageMNIST, 589 TrafficSignsFromHAT, 590 USPS, 591 ]: 592 # dataset classes that have `train` bool argument 593 dataset_test = self.original_dataset_python_class_t( 594 root=self.root_t, 595 train=False, 596 transform=self.test_transforms(), 597 target_transform=self.target_transform(), 598 ) 599 600 return dataset_test 601 602 elif self.original_dataset_python_class_t in [ 603 Country211, 604 DTD, 605 FER2013, 606 FGVCAircraftVariant, 607 FGVCAircraftFamily, 608 FGVCAircraftManufacturer, 609 Flowers102, 610 Food101, 611 GTSRB, 612 OxfordIIITPet2, 613 OxfordIIITPet37, 614 PCAM, 615 RenderedSST2, 616 StanfordCars, 617 SVHN, 618 ]: 619 # dataset classes that have `split` argument with 'test' 620 621 dataset_test = self.original_dataset_python_class_t( 622 root=self.root_t, 623 split="test", 624 transform=self.test_transforms(), 625 target_transform=self.target_transform(), 626 ) 627 628 return dataset_test 629 630 elif self.original_dataset_python_class_t in [ 631 Caltech101, 632 Caltech256, 633 EuroSAT, 634 SEMEION, 635 SUN397, 636 ]: 637 # dataset classes that don't have train and test splt 638 639 dataset_all = self.original_dataset_python_class_t( 640 root=self.root_t, 641 transform=self.train_and_val_transforms(), 642 target_transform=self.target_transform(), 643 ) 644 645 _, dataset_test = random_split( 646 dataset_all, 647 lengths=[1 - self.test_percentage, self.test_percentage], 648 
generator=torch.Generator().manual_seed(42), 649 ) 650 651 return dataset_test 652 653 elif self.original_dataset_python_class_t in [CelebA]: 654 # special case 655 dataset_test = self.original_dataset_python_class_t( 656 root=self.root_t, 657 split="test", 658 target_type="identity", 659 transform=self.test_transforms(), 660 target_transform=self.target_transform(), 661 ) 662 663 return dataset_test 664 665 elif self.original_dataset_python_class_t in [TinyImageNet]: 666 # special case 667 dataset_test = self.original_dataset_python_class_t( 668 root=self.root_t, 669 split="val", 670 transform=self.test_transforms(), 671 target_transform=self.target_transform(), 672 ) 673 674 return dataset_test
class Combined(CLCombinedDataset):
    r"""Combined CL dataset from available datasets."""

    # These are classes, not instances, hence `type[VisionDataset]`.
    AVAILABLE_DATASETS: list[type[VisionDataset]] = [
        ArabicHandwrittenDigits,
        CIFAR10,
        CIFAR100,
        CUB2002011,
        Caltech101,
        Caltech256,
        CelebA,
        Country211,
        DTD,
        EMNISTBalanced,
        EMNISTByClass,
        EMNISTByMerge,
        EMNISTDigits,
        EMNISTLetters,
        EuroSAT,
        FER2013,
        FGVCAircraftFamily,
        FGVCAircraftManufacturer,
        FGVCAircraftVariant,
        FaceScrub10,
        FaceScrub100,
        FaceScrubFromHAT,
        FaceScrub20,
        FaceScrub50,
        FashionMNIST,
        Flowers102,
        Food101,
        GTSRB,
        KMNIST,
        KannadaMNIST,
        Linnaeus5_128,
        Linnaeus5_256,
        Linnaeus5_32,
        Linnaeus5_64,
        MNIST,
        NotMNIST,
        NotMNISTFromHAT,
        OxfordIIITPet2,
        OxfordIIITPet37,
        PCAM,
        RenderedSST2,
        SEMEION,
        SUN397,
        SVHN,
        SignLanguageMNIST,
        StanfordCars,
        TrafficSignsFromHAT,
        TinyImageNet,
        USPS,
    ]
    r"""The list of available datasets."""

    def __init__(
        self,
        datasets: list[str],
        root: list[str],
        validation_percentage: float,
        test_percentage: float,
        batch_size: int | list[int] = 1,
        num_workers: int | list[int] = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | list[Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | list[int | None] = None,
        to_tensor: bool | list[bool] = True,
        resize: tuple[int, int] | None | list[tuple[int, int] | None] = None,
    ) -> None:
        r"""Initialize the Combined Torchvision dataset object providing the root where data files live.

        **Args:**
        - **datasets** (`list[str]`): the list of dataset class paths for each task. Each element in the list must be a string referring to a valid PyTorch Dataset class. It needs to be one in `self.AVAILABLE_DATASETS`.
        - **root** (`list[str]`): the list of root directory where the original data files for constructing the CL dataset physically live.
        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset).
        - **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset).
        - **batch_size** (`int` | `list[int]`): The batch size in train, val, test dataloader. If `list[int]`, each integer is the batch size for each task.
        - **num_workers** (`int` | `list[int]`): the number of workers for dataloaders. If `list[int]`, each integer is the num of workers for each task.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or list of them): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize, permute and so on are not included. If it is a list, each item is the custom transforms for each task.
        - **repeat_channels** (`int` | `None` | list of them): the number of channels to repeat for each task. Default is None, which means no repeat. If not None, it should be an integer. If it is a list, each item is the number of channels to repeat for each task.
        - **to_tensor** (`bool` | `list[bool]`): whether to include `ToTensor()` transform. Default is True.
        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. If it is a list, each item is the size to resize for each task.
        """
        super().__init__(
            datasets=datasets,
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.test_percentage: float = test_percentage
        """The percentage to randomly split some data into test data."""
        self.validation_percentage: float = validation_percentage
        """The percentage to randomly split some training data into validation data."""

    def prepare_data(self) -> None:
        r"""Download the original datasets if haven't.

        Downloads the original datasets of ALL tasks at once (only when called
        for the first task). Failures are collected and raised together at the
        end so one broken mirror doesn't abort the remaining downloads.

        **Raises:**
        - **RuntimeError**: if any dataset failed to download.
        - **ValueError**: if a configured dataset class has no known download API.
        """
        if self.task_id != 1:
            return  # download all original datasets only at the beginning of first task

        failed_dataset_classes = []
        for task_id in range(1, self.num_tasks + 1):
            # NOTE(review): `self.root` / `self.original_dataset_python_classes`
            # appear to be 1-indexed by task — confirm against CLCombinedDataset.
            root = self.root[task_id]
            dataset_class = self.original_dataset_python_classes[task_id]
            # torchvision datasets might have different APIs
            try:
                # collect the error and raise it at the end to avoid stopping the whole download process
                if dataset_class in [
                    ArabicHandwrittenDigits,
                    KannadaMNIST,
                    SignLanguageMNIST,
                ]:
                    # these datasets have no automatic download function, we require users to download them manually
                    # the following code is just to check if the dataset is already downloaded
                    dataset_class(root=root, train=True, download=False)
                    dataset_class(root=root, train=False, download=False)

                elif dataset_class in [
                    Caltech101,
                    Caltech256,
                    EuroSAT,
                    SEMEION,
                    SUN397,
                ]:
                    # dataset classes that don't have any train, val, test split
                    dataset_class(root=root, download=True)

                elif dataset_class in [
                    # ArabicHandwrittenDigits, KannadaMNIST and SignLanguageMNIST
                    # are intentionally absent here: the first branch above
                    # already handles them (they cannot be auto-downloaded).
                    CIFAR10,
                    CIFAR100,
                    CUB2002011,
                    EMNISTByClass,
                    EMNISTByMerge,
                    EMNISTBalanced,
                    EMNISTLetters,
                    EMNISTDigits,
                    FaceScrub10,
                    FaceScrub20,
                    FaceScrub50,
                    FaceScrub100,
                    FaceScrubFromHAT,
                    FashionMNIST,
                    KMNIST,
                    Linnaeus5_32,
                    Linnaeus5_64,
                    Linnaeus5_128,
                    Linnaeus5_256,
                    MNIST,
                    NotMNIST,
                    NotMNISTFromHAT,
                    TrafficSignsFromHAT,
                    USPS,
                ]:
                    # dataset classes that have `train` bool argument
                    dataset_class(root=root, train=True, download=True)
                    dataset_class(root=root, train=False, download=True)

                elif dataset_class in [
                    Food101,
                    GTSRB,
                    StanfordCars,
                    SVHN,
                ]:
                    # dataset classes that have `split` argument with 'train', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class in [Country211]:
                    # dataset classes that have `split` argument with 'train', 'valid', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="valid",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )

                elif dataset_class in [
                    DTD,
                    FGVCAircraftVariant,
                    FGVCAircraftFamily,
                    FGVCAircraftManufacturer,
                    Flowers102,
                    PCAM,
                    RenderedSST2,
                ]:
                    # dataset classes that have `split` argument with 'train', 'val', 'test'
                    dataset_class(
                        root=root,
                        split="train",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="val",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]:
                    # dataset classes that have `split` argument with 'trainval', 'test'
                    dataset_class(
                        root=root,
                        split="trainval",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        download=True,
                    )
                elif dataset_class == CelebA:
                    # special case: CelebA needs `target_type` as well
                    dataset_class(
                        root=root,
                        split="train",
                        target_type="identity",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="valid",
                        target_type="identity",
                        download=True,
                    )
                    dataset_class(
                        root=root,
                        split="test",
                        target_type="identity",
                        download=True,
                    )
                elif dataset_class == FER2013:
                    # special case: FER2013 has no `download` argument
                    dataset_class(
                        root=root,
                        split="train",
                    )
                    dataset_class(
                        root=root,
                        split="test",
                    )
                elif dataset_class == TinyImageNet:
                    # special case: downloads on construction
                    dataset_class(root=root)
                else:
                    # fail fast on a dataset class this dispatch doesn't know;
                    # previously this fell through silently and nothing was downloaded
                    raise ValueError(
                        f"Dataset class {dataset_class} is not supported by Combined.prepare_data()."
                    )

            except RuntimeError:
                failed_dataset_classes.append(dataset_class)  # save for later prompt
            else:
                pylogger.debug(
                    "The original %s dataset for task %s has been downloaded to %s.",
                    dataset_class,
                    task_id,
                    root,
                )

        if failed_dataset_classes:
            raise RuntimeError(
                f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again or manually."
            )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        """Get the training and validation dataset of task `self.task_id`.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.

        **Raises:**
        - **ValueError**: if the task's dataset class is not supported.
        """
        # torchvision datasets might have different APIs
        if self.original_dataset_python_class_t in [
            ArabicHandwrittenDigits,
            CIFAR10,
            CIFAR100,
            CUB2002011,
            EMNISTByClass,
            EMNISTByMerge,
            EMNISTBalanced,
            EMNISTLetters,
            EMNISTDigits,
            FaceScrub10,
            FaceScrub20,
            FaceScrub50,
            FaceScrub100,
            FaceScrubFromHAT,
            FashionMNIST,
            KannadaMNIST,
            KMNIST,
            Linnaeus5_32,
            Linnaeus5_64,
            Linnaeus5_128,
            Linnaeus5_256,
            MNIST,
            NotMNIST,
            NotMNISTFromHAT,
            SignLanguageMNIST,
            TrafficSignsFromHAT,
            USPS,
        ]:
            # dataset classes that have `train` bool argument
            dataset_train_and_val = self.original_dataset_python_class_t(
                root=self.root_t,
                train=True,
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
            )
        elif self.original_dataset_python_class_t in [
            Caltech101,
            Caltech256,
            EuroSAT,
            SEMEION,
            SUN397,
        ]:
            # dataset classes that don't have train and test split
            dataset_all = self.original_dataset_python_class_t(
                root=self.root_t,
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            # carve out the test portion first (same seed as test_dataset() so
            # the train/test partition is identical in both methods)
            dataset_train_and_val, _ = random_split(
                dataset_all,
                lengths=[
                    1 - self.test_percentage,
                    self.test_percentage,
                ],
                generator=torch.Generator().manual_seed(
                    42
                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
            )
        elif self.original_dataset_python_class_t in [Country211]:
            # dataset classes that have `split` argument with 'train', 'valid', 'test'
            dataset_train = self.original_dataset_python_class_t(
                root=self.root_t,
                split="train",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )
            dataset_val = self.original_dataset_python_class_t(
                root=self.root_t,
                split="valid",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            return dataset_train, dataset_val

        elif self.original_dataset_python_class_t in [
            DTD,
            FGVCAircraftVariant,
            FGVCAircraftFamily,
            FGVCAircraftManufacturer,
            Flowers102,
            PCAM,
            RenderedSST2,
        ]:
            # dataset classes that have `split` argument with 'train', 'val', 'test'
            dataset_train = self.original_dataset_python_class_t(
                root=self.root_t,
                split="train",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            dataset_val = self.original_dataset_python_class_t(
                root=self.root_t,
                split="val",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            return dataset_train, dataset_val

        elif self.original_dataset_python_class_t in [
            FER2013,
            Food101,
            GTSRB,
            StanfordCars,
            SVHN,
            TinyImageNet,
        ]:
            # dataset classes that have `split` argument with 'train', 'test'
            dataset_train_and_val = self.original_dataset_python_class_t(
                root=self.root_t,
                split="train",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
            )

        elif self.original_dataset_python_class_t in [OxfordIIITPet2, OxfordIIITPet37]:
            # dataset classes that have `split` argument with 'trainval', 'test'
            dataset_train_and_val = self.original_dataset_python_class_t(
                root=self.root_t,
                split="trainval",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            return random_split(
                dataset_train_and_val,
                lengths=[1 - self.validation_percentage, self.validation_percentage],
                generator=torch.Generator().manual_seed(
                    42
                ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
            )

        elif self.original_dataset_python_class_t in [CelebA]:
            # special case: CelebA needs `target_type` as well
            dataset_train = self.original_dataset_python_class_t(
                root=self.root_t,
                split="train",
                target_type="identity",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            dataset_val = self.original_dataset_python_class_t(
                root=self.root_t,
                split="valid",
                target_type="identity",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
            )

            return dataset_train, dataset_val

        # fail fast instead of silently returning None on an unknown class
        raise ValueError(
            f"Dataset class {self.original_dataset_python_class_t} is not supported by Combined.train_and_val_dataset()."
        )

    def test_dataset(self) -> Dataset:
        r"""Get the test dataset of task `self.task_id`.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`.

        **Raises:**
        - **ValueError**: if the task's dataset class is not supported.
        """
        # torchvision datasets might have different APIs
        if self.original_dataset_python_class_t in [
            ArabicHandwrittenDigits,
            CIFAR10,
            CIFAR100,
            CUB2002011,
            EMNISTByClass,
            EMNISTByMerge,
            EMNISTBalanced,
            EMNISTLetters,
            EMNISTDigits,
            FaceScrub10,
            FaceScrub20,
            FaceScrub50,
            FaceScrub100,
            FaceScrubFromHAT,
            FashionMNIST,
            KannadaMNIST,
            KMNIST,
            Linnaeus5_32,
            Linnaeus5_64,
            Linnaeus5_128,
            Linnaeus5_256,
            MNIST,
            NotMNIST,
            NotMNISTFromHAT,
            SignLanguageMNIST,
            TrafficSignsFromHAT,
            USPS,
        ]:
            # dataset classes that have `train` bool argument
            dataset_test = self.original_dataset_python_class_t(
                root=self.root_t,
                train=False,
                transform=self.test_transforms(),
                target_transform=self.target_transform(),
            )

            return dataset_test

        elif self.original_dataset_python_class_t in [
            Country211,
            DTD,
            FER2013,
            FGVCAircraftVariant,
            FGVCAircraftFamily,
            FGVCAircraftManufacturer,
            Flowers102,
            Food101,
            GTSRB,
            OxfordIIITPet2,
            OxfordIIITPet37,
            PCAM,
            RenderedSST2,
            StanfordCars,
            SVHN,
        ]:
            # dataset classes that have `split` argument with 'test'
            dataset_test = self.original_dataset_python_class_t(
                root=self.root_t,
                split="test",
                transform=self.test_transforms(),
                target_transform=self.target_transform(),
            )

            return dataset_test

        elif self.original_dataset_python_class_t in [
            Caltech101,
            Caltech256,
            EuroSAT,
            SEMEION,
            SUN397,
        ]:
            # dataset classes that don't have train and test split.
            # FIX: use test_transforms() here — previously train_and_val_transforms()
            # was applied, leaking train-only custom transforms into test data.
            # The transform does not affect random_split, so the partition below
            # stays identical to the one in train_and_val_dataset().
            dataset_all = self.original_dataset_python_class_t(
                root=self.root_t,
                transform=self.test_transforms(),
                target_transform=self.target_transform(),
            )

            _, dataset_test = random_split(
                dataset_all,
                lengths=[1 - self.test_percentage, self.test_percentage],
                generator=torch.Generator().manual_seed(42),
            )

            return dataset_test

        elif self.original_dataset_python_class_t in [CelebA]:
            # special case: CelebA needs `target_type` as well
            dataset_test = self.original_dataset_python_class_t(
                root=self.root_t,
                split="test",
                target_type="identity",
                transform=self.test_transforms(),
                target_transform=self.target_transform(),
            )

            return dataset_test

        elif self.original_dataset_python_class_t in [TinyImageNet]:
            # special case: TinyImageNet has no labeled 'test' split, use 'val'
            dataset_test = self.original_dataset_python_class_t(
                root=self.root_t,
                split="val",
                transform=self.test_transforms(),
                target_transform=self.target_transform(),
            )

            return dataset_test

        # fail fast instead of silently returning None on an unknown class
        raise ValueError(
            f"Dataset class {self.original_dataset_python_class_t} is not supported by Combined.test_dataset()."
        )
Combined CL dataset from available datasets.
Combined( datasets: list[str], root: list[str], validation_percentage: float, test_percentage: float, batch_size: int | list[int] = 1, num_workers: int | list[int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, list[Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | list[int | None] = None, to_tensor: bool | list[bool] = True, resize: tuple[int, int] | None | list[tuple[int, int] | None] = None)
131 def __init__( 132 self, 133 datasets: list[str], 134 root: list[str], 135 validation_percentage: float, 136 test_percentage: float, 137 batch_size: int | list[int] = 1, 138 num_workers: int | list[int] = 0, 139 custom_transforms: ( 140 Callable 141 | transforms.Compose 142 | None 143 | list[Callable | transforms.Compose | None] 144 ) = None, 145 repeat_channels: int | None | list[int | None] = None, 146 to_tensor: bool | list[bool] = True, 147 resize: tuple[int, int] | None | list[tuple[int, int] | None] = None, 148 ) -> None: 149 r"""Initialize the Combined Torchvision dataset object providing the root where data files live. 150 151 **Args:** 152 - **datasets** (`list[str]`): the list of dataset class paths for each task. Each element in the list must be a string referring to a valid PyTorch Dataset class. It needs to be one in `self.AVAILABLE_DATASETS`. 153 - **root** (`list[str]`): the list of root directory where the original data files for constructing the CL dataset physically live. 154 - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset). 155 - **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset). 156 - **batch_size** (`int` | `list[int]`): The batch size in train, val, test dataloader. If `list[str]`, it should be a list of integers, each integer is the batch size for each task. 157 - **num_workers** (`int` | `list[int]`): the number of workers for dataloaders. If `list[str]`, it should be a list of integers, each integer is the num of workers for each task. 158 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or list of them): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize, permute and so on are not included. 
If it is a list, each item is the custom transforms for each task. 159 - **repeat_channels** (`int` | `None` | list of them): the number of channels to repeat for each task. Default is None, which means no repeat. If not None, it should be an integer. If it is a list, each item is the number of channels to repeat for each task. 160 - **to_tensor** (`bool` | `list[bool]`): whether to include `ToTensor()` transform. Default is True. 161 - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. If it is a list, each item is the size to resize for each task. 162 """ 163 super().__init__( 164 datasets=datasets, 165 root=root, 166 batch_size=batch_size, 167 num_workers=num_workers, 168 custom_transforms=custom_transforms, 169 repeat_channels=repeat_channels, 170 to_tensor=to_tensor, 171 resize=resize, 172 ) 173 174 self.test_percentage: float = test_percentage 175 """The percentage to randomly split some data into test data.""" 176 self.validation_percentage: float = validation_percentage 177 """The percentage to randomly split some training data into validation data."""
Initialize the Combined Torchvision dataset object providing the root where data files live.
Args:
- datasets (
`list[str]`): the list of dataset class paths for each task. Each element in the list must be a string referring to a valid PyTorch Dataset class. It needs to be one in `self.AVAILABLE_DATASETS`. - root (
list[str]): the list of root directory where the original data files for constructing the CL dataset physically live. - validation_percentage (
float): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset). - test_percentage (
float): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset). - batch_size (
`int` | `list[int]`): the batch size in train, val, test dataloader. If `list[int]`, it should be a list of integers, each integer is the batch size for each task. - num_workers (
`int` | `list[int]`): the number of workers for dataloaders. If `list[int]`, it should be a list of integers, each integer is the num of workers for each task. - custom_transforms (
transformortransforms.ComposeorNoneor list of them): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform.ToTensor(), normalize, permute and so on are not included. If it is a list, each item is the custom transforms for each task. - repeat_channels (
int|None| list of them): the number of channels to repeat for each task. Default is None, which means no repeat. If not None, it should be an integer. If it is a list, each item is the number of channels to repeat for each task. - to_tensor (
`bool` | `list[bool]`): whether to include `ToTensor()` transform. Default is True. - resize (
tuple[int, int]|Noneor list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. If it is a list, each item is the size to resize for each task.
AVAILABLE_DATASETS: list[torchvision.datasets.vision.VisionDataset] =
[<class 'clarena.stl_datasets.raw.ahdd.ArabicHandwrittenDigits'>, <class 'torchvision.datasets.cifar.CIFAR10'>, <class 'torchvision.datasets.cifar.CIFAR100'>, <class 'clarena.stl_datasets.raw.cub2002011.CUB2002011'>, <class 'torchvision.datasets.caltech.Caltech101'>, <class 'torchvision.datasets.caltech.Caltech256'>, <class 'torchvision.datasets.celeba.CelebA'>, <class 'torchvision.datasets.country211.Country211'>, <class 'torchvision.datasets.dtd.DTD'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTBalanced'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTByClass'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTByMerge'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTDigits'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTLetters'>, <class 'torchvision.datasets.eurosat.EuroSAT'>, <class 'torchvision.datasets.fer2013.FER2013'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftFamily'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftManufacturer'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftVariant'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub10'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub100'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrubFromHAT'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub20'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub50'>, <class 'torchvision.datasets.mnist.FashionMNIST'>, <class 'torchvision.datasets.flowers102.Flowers102'>, <class 'torchvision.datasets.food101.Food101'>, <class 'torchvision.datasets.gtsrb.GTSRB'>, <class 'torchvision.datasets.mnist.KMNIST'>, <class 'clarena.stl_datasets.raw.kannada_mnist.KannadaMNIST'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_128'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_256'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_32'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_64'>, <class 'torchvision.datasets.mnist.MNIST'>, <class 'clarena.stl_datasets.raw.notmnist.NotMNIST'>, <class 
'clarena.stl_datasets.raw.notmnist.NotMNISTFromHAT'>, <class 'clarena.stl_datasets.raw.oxford_iiit_pet.OxfordIIITPet2'>, <class 'clarena.stl_datasets.raw.oxford_iiit_pet.OxfordIIITPet37'>, <class 'torchvision.datasets.pcam.PCAM'>, <class 'torchvision.datasets.rendered_sst2.RenderedSST2'>, <class 'torchvision.datasets.semeion.SEMEION'>, <class 'torchvision.datasets.sun397.SUN397'>, <class 'torchvision.datasets.svhn.SVHN'>, <class 'clarena.stl_datasets.raw.sign_language_mnist.SignLanguageMNIST'>, <class 'torchvision.datasets.stanford_cars.StanfordCars'>, <class 'clarena.stl_datasets.raw.traffic_signs.TrafficSignsFromHAT'>, <class 'tinyimagenet.TinyImageNet'>, <class 'torchvision.datasets.usps.USPS'>]
The list of available datasets.
validation_percentage: float
The percentage to randomly split some training data into validation data.
def
prepare_data(self) -> None:
179 def prepare_data(self) -> None: 180 r"""Download the original datasets if haven't.""" 181 182 if self.task_id != 1: 183 return # download all original datasets only at the beginning of first task 184 185 failed_dataset_classes = [] 186 for task_id in range(1, self.num_tasks + 1): 187 root = self.root[task_id] 188 dataset_class = self.original_dataset_python_classes[task_id] 189 # torchvision datasets might have different APIs 190 try: 191 # collect the error and raise it at the end to avoid stopping the whole download process 192 193 if dataset_class in [ 194 ArabicHandwrittenDigits, 195 KannadaMNIST, 196 SignLanguageMNIST, 197 ]: 198 # these datasets have no automatic download function, we require users to download them manually 199 # the following code is just to check if the dataset is already downloaded 200 dataset_class(root=root, train=True, download=False) 201 dataset_class(root=root, train=False, download=False) 202 203 elif dataset_class in [ 204 Caltech101, 205 Caltech256, 206 EuroSAT, 207 SEMEION, 208 SUN397, 209 ]: 210 # dataset classes that don't have any train, val, test split 211 dataset_class(root=root, download=True) 212 213 elif dataset_class in [ 214 ArabicHandwrittenDigits, 215 CIFAR10, 216 CIFAR100, 217 CUB2002011, 218 EMNISTByClass, 219 EMNISTByMerge, 220 EMNISTBalanced, 221 EMNISTLetters, 222 EMNISTDigits, 223 FaceScrub10, 224 FaceScrub20, 225 FaceScrub50, 226 FaceScrub100, 227 FaceScrubFromHAT, 228 FashionMNIST, 229 KannadaMNIST, 230 KMNIST, 231 Linnaeus5_32, 232 Linnaeus5_64, 233 Linnaeus5_128, 234 Linnaeus5_256, 235 MNIST, 236 NotMNIST, 237 NotMNISTFromHAT, 238 SignLanguageMNIST, 239 TrafficSignsFromHAT, 240 USPS, 241 ]: 242 # dataset classes that have `train` bool argument 243 dataset_class(root=root, train=True, download=True) 244 dataset_class(root=root, train=False, download=True) 245 246 elif dataset_class in [ 247 Food101, 248 GTSRB, 249 StanfordCars, 250 SVHN, 251 ]: 252 # dataset classes that have `split` argument with 'train', 
'test' 253 dataset_class( 254 root=root, 255 split="train", 256 download=True, 257 ) 258 dataset_class( 259 root=root, 260 split="test", 261 download=True, 262 ) 263 elif dataset_class in [Country211]: 264 # dataset classes that have `split` argument with 'train', 'valid', 'test' 265 dataset_class( 266 root=root, 267 split="train", 268 download=True, 269 ) 270 dataset_class( 271 root=root, 272 split="valid", 273 download=True, 274 ) 275 dataset_class( 276 root=root, 277 split="test", 278 download=True, 279 ) 280 281 elif dataset_class in [ 282 DTD, 283 FGVCAircraftVariant, 284 FGVCAircraftFamily, 285 FGVCAircraftManufacturer, 286 Flowers102, 287 PCAM, 288 RenderedSST2, 289 ]: 290 # dataset classes that have `split` argument with 'train', 'val', 'test' 291 dataset_class( 292 root=root, 293 split="train", 294 download=True, 295 ) 296 dataset_class( 297 root=root, 298 split="val", 299 download=True, 300 ) 301 dataset_class( 302 root=root, 303 split="test", 304 download=True, 305 ) 306 elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]: 307 # dataset classes that have `split` argument with 'trainval', 'test' 308 dataset_class( 309 root=root, 310 split="trainval", 311 download=True, 312 ) 313 dataset_class( 314 root=root, 315 split="test", 316 download=True, 317 ) 318 elif dataset_class == CelebA: 319 # special case 320 dataset_class( 321 root=root, 322 split="train", 323 target_type="identity", 324 download=True, 325 ) 326 dataset_class( 327 root=root, 328 split="valid", 329 target_type="identity", 330 download=True, 331 ) 332 dataset_class( 333 root=root, 334 split="test", 335 target_type="identity", 336 download=True, 337 ) 338 elif dataset_class == FER2013: 339 # special case 340 dataset_class( 341 root=root, 342 split="train", 343 ) 344 dataset_class( 345 root=root, 346 split="test", 347 ) 348 elif dataset_class == TinyImageNet: 349 # special case 350 dataset_class(root=root) 351 352 except RuntimeError: 353 failed_dataset_classes.append(dataset_class) # save 
for later prompt 354 else: 355 pylogger.debug( 356 "The original %s dataset for task %s has been downloaded to %s.", 357 dataset_class, 358 task_id, 359 root, 360 ) 361 362 if failed_dataset_classes: 363 raise RuntimeError( 364 f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again or manually." 365 )
Download the original datasets if haven't.
def
train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
367 def train_and_val_dataset(self) -> tuple[Dataset, Dataset]: 368 """Get the training and validation dataset of task `self.task_id`. 369 370 **Returns:** 371 - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`. 372 """ 373 # torchvision datasets might have different APIs 374 if self.original_dataset_python_class_t in [ 375 ArabicHandwrittenDigits, 376 CIFAR10, 377 CIFAR100, 378 CUB2002011, 379 EMNISTByClass, 380 EMNISTByMerge, 381 EMNISTBalanced, 382 EMNISTLetters, 383 EMNISTDigits, 384 FaceScrub10, 385 FaceScrub20, 386 FaceScrub50, 387 FaceScrub100, 388 FaceScrubFromHAT, 389 FashionMNIST, 390 KannadaMNIST, 391 KMNIST, 392 Linnaeus5_32, 393 Linnaeus5_64, 394 Linnaeus5_128, 395 Linnaeus5_256, 396 MNIST, 397 NotMNIST, 398 NotMNISTFromHAT, 399 SignLanguageMNIST, 400 TrafficSignsFromHAT, 401 USPS, 402 ]: 403 # dataset classes that have `train` bool argument 404 dataset_train_and_val = self.original_dataset_python_class_t( 405 root=self.root_t, 406 train=True, 407 transform=self.train_and_val_transforms(), 408 target_transform=self.target_transform(), 409 ) 410 411 return random_split( 412 dataset_train_and_val, 413 lengths=[1 - self.validation_percentage, self.validation_percentage], 414 generator=torch.Generator().manual_seed( 415 42 416 ), # this must be set fixed to make sure the datasets across experiments are the same. 
Don't handle it to global seed as it might vary across experiments 417 ) 418 elif self.original_dataset_python_class_t in [ 419 Caltech101, 420 Caltech256, 421 EuroSAT, 422 SEMEION, 423 SUN397, 424 ]: 425 # dataset classes that don't have train and test splt 426 dataset_all = self.original_dataset_python_class_t( 427 root=self.root_t, 428 transform=self.train_and_val_transforms(), 429 target_transform=self.target_transform(), 430 ) 431 432 dataset_train_and_val, _ = random_split( 433 dataset_all, 434 lengths=[ 435 1 - self.test_percentage, 436 self.test_percentage, 437 ], 438 generator=torch.Generator().manual_seed( 439 42 440 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 441 ) 442 443 return random_split( 444 dataset_train_and_val, 445 lengths=[1 - self.validation_percentage, self.validation_percentage], 446 generator=torch.Generator().manual_seed( 447 42 448 ), # this must be set fixed to make sure the datasets across experiments are the same. 
Don't handle it to global seed as it might vary across experiments 449 ) 450 elif self.original_dataset_python_class_t in [Country211]: 451 # dataset classes that have `split` argument with 'train', 'valid', 'test' 452 dataset_train = self.original_dataset_python_class_t( 453 root=self.root_t, 454 split="train", 455 transform=self.train_and_val_transforms(), 456 target_transform=self.target_transform(), 457 ) 458 dataset_val = self.original_dataset_python_class_t( 459 root=self.root_t, 460 split="valid", 461 transform=self.train_and_val_transforms(), 462 target_transform=self.target_transform(), 463 ) 464 465 return dataset_train, dataset_val 466 467 elif self.original_dataset_python_class_t in [ 468 DTD, 469 FGVCAircraftVariant, 470 FGVCAircraftFamily, 471 FGVCAircraftManufacturer, 472 Flowers102, 473 PCAM, 474 RenderedSST2, 475 ]: 476 # dataset classes that have `split` argument with 'train', 'val', 'test' 477 dataset_train = self.original_dataset_python_class_t( 478 root=self.root_t, 479 split="train", 480 transform=self.train_and_val_transforms(), 481 target_transform=self.target_transform(), 482 ) 483 484 dataset_val = self.original_dataset_python_class_t( 485 root=self.root_t, 486 split="val", 487 transform=self.train_and_val_transforms(), 488 target_transform=self.target_transform(), 489 ) 490 491 return dataset_train, dataset_val 492 493 elif self.original_dataset_python_class_t in [ 494 FER2013, 495 Food101, 496 GTSRB, 497 StanfordCars, 498 SVHN, 499 TinyImageNet, 500 ]: 501 # dataset classes that have `split` argument with 'train', 'test' 502 503 dataset_train_and_val = self.original_dataset_python_class_t( 504 root=self.root_t, 505 split="train", 506 transform=self.train_and_val_transforms(), 507 target_transform=self.target_transform(), 508 ) 509 510 return random_split( 511 dataset_train_and_val, 512 lengths=[1 - self.validation_percentage, self.validation_percentage], 513 generator=torch.Generator().manual_seed( 514 42 515 ), # this must be set fixed 
to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 516 ) 517 518 elif self.original_dataset_python_class_t in [OxfordIIITPet2, OxfordIIITPet37]: 519 # dataset classes that have `split` argument with 'trainval', 'test' 520 521 dataset_train_and_val = self.original_dataset_python_class_t( 522 root=self.root_t, 523 split="trainval", 524 transform=self.train_and_val_transforms(), 525 target_transform=self.target_transform(), 526 ) 527 528 return random_split( 529 dataset_train_and_val, 530 lengths=[1 - self.validation_percentage, self.validation_percentage], 531 generator=torch.Generator().manual_seed( 532 42 533 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 534 ) 535 536 elif self.original_dataset_python_class_t in [CelebA]: 537 # special case 538 dataset_train = self.original_dataset_python_class_t( 539 root=self.root_t, 540 split="train", 541 target_type="identity", 542 transform=self.train_and_val_transforms(), 543 target_transform=self.target_transform(), 544 ) 545 546 dataset_val = self.original_dataset_python_class_t( 547 root=self.root_t, 548 split="valid", 549 target_type="identity", 550 transform=self.train_and_val_transforms(), 551 target_transform=self.target_transform(), 552 ) 553 554 return dataset_train, dataset_val
Get the training and validation dataset of task self.task_id.
Returns:
- train_and_val_dataset (
`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.
def
test_dataset(self) -> torch.utils.data.dataset.Dataset:
556 def test_dataset(self) -> Dataset: 557 r"""Get the test dataset of task `self.task_id`. 558 559 **Returns:** 560 - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`. 561 """ 562 # torchvision datasets might have different APIs 563 564 if self.original_dataset_python_class_t in [ 565 ArabicHandwrittenDigits, 566 CIFAR10, 567 CIFAR100, 568 CUB2002011, 569 EMNISTByClass, 570 EMNISTByMerge, 571 EMNISTBalanced, 572 EMNISTLetters, 573 EMNISTDigits, 574 FaceScrub10, 575 FaceScrub20, 576 FaceScrub50, 577 FaceScrub100, 578 FaceScrubFromHAT, 579 FashionMNIST, 580 KannadaMNIST, 581 KMNIST, 582 Linnaeus5_32, 583 Linnaeus5_64, 584 Linnaeus5_128, 585 Linnaeus5_256, 586 MNIST, 587 NotMNIST, 588 NotMNISTFromHAT, 589 SignLanguageMNIST, 590 TrafficSignsFromHAT, 591 USPS, 592 ]: 593 # dataset classes that have `train` bool argument 594 dataset_test = self.original_dataset_python_class_t( 595 root=self.root_t, 596 train=False, 597 transform=self.test_transforms(), 598 target_transform=self.target_transform(), 599 ) 600 601 return dataset_test 602 603 elif self.original_dataset_python_class_t in [ 604 Country211, 605 DTD, 606 FER2013, 607 FGVCAircraftVariant, 608 FGVCAircraftFamily, 609 FGVCAircraftManufacturer, 610 Flowers102, 611 Food101, 612 GTSRB, 613 OxfordIIITPet2, 614 OxfordIIITPet37, 615 PCAM, 616 RenderedSST2, 617 StanfordCars, 618 SVHN, 619 ]: 620 # dataset classes that have `split` argument with 'test' 621 622 dataset_test = self.original_dataset_python_class_t( 623 root=self.root_t, 624 split="test", 625 transform=self.test_transforms(), 626 target_transform=self.target_transform(), 627 ) 628 629 return dataset_test 630 631 elif self.original_dataset_python_class_t in [ 632 Caltech101, 633 Caltech256, 634 EuroSAT, 635 SEMEION, 636 SUN397, 637 ]: 638 # dataset classes that don't have train and test splt 639 640 dataset_all = self.original_dataset_python_class_t( 641 root=self.root_t, 642 transform=self.train_and_val_transforms(), 643 
target_transform=self.target_transform(), 644 ) 645 646 _, dataset_test = random_split( 647 dataset_all, 648 lengths=[1 - self.test_percentage, self.test_percentage], 649 generator=torch.Generator().manual_seed(42), 650 ) 651 652 return dataset_test 653 654 elif self.original_dataset_python_class_t in [CelebA]: 655 # special case 656 dataset_test = self.original_dataset_python_class_t( 657 root=self.root_t, 658 split="test", 659 target_type="identity", 660 transform=self.test_transforms(), 661 target_transform=self.target_transform(), 662 ) 663 664 return dataset_test 665 666 elif self.original_dataset_python_class_t in [TinyImageNet]: 667 # special case 668 dataset_test = self.original_dataset_python_class_t( 669 root=self.root_t, 670 split="val", 671 transform=self.test_transforms(), 672 target_transform=self.target_transform(), 673 ) 674 675 return dataset_test
Get the test dataset of task self.task_id.
Returns:
- test_dataset (
`Dataset`): the test dataset of task `self.task_id`.