clarena.mtl_datasets.combined
The submodule in `mtl_datasets` for the combined dataset used for multi-task learning.
1r""" 2The submodule in `mtl_datasets` for CelebA dataset used for multi-task learning. 3""" 4 5__all__ = ["Combined"] 6 7import logging 8from typing import Callable 9 10import torch 11from tinyimagenet import TinyImageNet 12from torch.utils.data import Dataset, random_split 13from torchvision.datasets import ( 14 CIFAR10, 15 CIFAR100, 16 DTD, 17 FER2013, 18 GTSRB, 19 KMNIST, 20 MNIST, 21 PCAM, 22 SEMEION, 23 SUN397, 24 SVHN, 25 USPS, 26 Caltech101, 27 Caltech256, 28 CelebA, 29 Country211, 30 EuroSAT, 31 FashionMNIST, 32 Flowers102, 33 Food101, 34 RenderedSST2, 35 StanfordCars, 36) 37from torchvision.datasets.vision import VisionDataset 38from torchvision.transforms import transforms 39 40from clarena.mtl_datasets import MTLCombinedDataset 41from clarena.stl_datasets.raw import ( 42 CUB2002011, 43 ArabicHandwrittenDigits, 44 EMNISTBalanced, 45 EMNISTByClass, 46 EMNISTByMerge, 47 EMNISTDigits, 48 EMNISTLetters, 49 FaceScrub10, 50 FaceScrub20, 51 FaceScrub50, 52 FaceScrub100, 53 FaceScrubFromHAT, 54 FGVCAircraftFamily, 55 FGVCAircraftManufacturer, 56 FGVCAircraftVariant, 57 KannadaMNIST, 58 Linnaeus5_32, 59 Linnaeus5_64, 60 Linnaeus5_128, 61 Linnaeus5_256, 62 NotMNIST, 63 NotMNISTFromHAT, 64 OxfordIIITPet2, 65 OxfordIIITPet37, 66 SignLanguageMNIST, 67 TrafficSignsFromHAT, 68) 69 70# always get logger for built-in logging in each module 71pylogger = logging.getLogger(__name__) 72 73 74class Combined(MTLCombinedDataset): 75 r"""Combined MTL dataset from available datasets.""" 76 77 AVAILABLE_DATASETS: list[VisionDataset] = [ 78 ArabicHandwrittenDigits, 79 CIFAR10, 80 CIFAR100, 81 CUB2002011, 82 Caltech101, 83 Caltech256, 84 CelebA, 85 Country211, 86 DTD, 87 EMNISTBalanced, 88 EMNISTByClass, 89 EMNISTByMerge, 90 EMNISTDigits, 91 EMNISTLetters, 92 EuroSAT, 93 FER2013, 94 FGVCAircraftFamily, 95 FGVCAircraftManufacturer, 96 FGVCAircraftVariant, 97 FaceScrub10, 98 FaceScrub100, 99 FaceScrubFromHAT, 100 FaceScrub20, 101 FaceScrub50, 102 FashionMNIST, 103 Flowers102, 104 Food101, 105 GTSRB, 106 KMNIST, 107 KannadaMNIST, 108 Linnaeus5_128, 109 Linnaeus5_256, 110 Linnaeus5_32, 111 Linnaeus5_64, 112 MNIST, 113 NotMNIST, 114 NotMNISTFromHAT, 115 OxfordIIITPet2, 116 OxfordIIITPet37, 117 PCAM, 118 RenderedSST2, 119 SEMEION, 120 SUN397, 121 SVHN, 122 SignLanguageMNIST, 123 StanfordCars, 124 TrafficSignsFromHAT, 125 TinyImageNet, 126 USPS, 127 ] 128 r"""The list of available datasets.""" 129 130 def __init__( 131 self, 132 datasets: dict[int, str], 133 root: str | dict[int, str], 134 validation_percentage: float, 135 test_percentage: float, 136 sampling_strategy: str = "mixed", 137 batch_size: int = 1, 138 num_workers: int = 0, 139 custom_transforms: ( 140 Callable 141 | transforms.Compose 142 | None 143 | dict[int, Callable | transforms.Compose | None] 144 ) = None, 145 repeat_channels: int | None | dict[int, int | None] = None, 146 to_tensor: bool | dict[int, bool] = True, 147 resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, 148 ) -> None: 149 r""" 150 **Args:** 151 - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task. 152 - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `dict[int, str]`, it should be a dict of task IDs and their corresponding root directories. 
153 - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if validation set is not provided in the dataset). 154 - **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if test set is not provided in the dataset). 155 - **sampling_strategy** (`str`): the sampling strategy that construct training batch from each task's dataset; one of: 156 - 'mixed': mixed sampling strategy, which samples from all tasks' datasets. 157 - **batch_size** (`int`): The batch size in train, val, test dataloader. 158 - **num_workers** (`int`): the number of workers for dataloaders. 159 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization and so on are not included. 160 If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied. 161 - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. 162 If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied. 163 - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. 164 If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks. 165 - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied. 
166 """ 167 super().__init__( 168 datasets=datasets, 169 root=root, 170 sampling_strategy=sampling_strategy, 171 batch_size=batch_size, 172 num_workers=num_workers, 173 custom_transforms=custom_transforms, 174 repeat_channels=repeat_channels, 175 to_tensor=to_tensor, 176 resize=resize, 177 ) 178 179 self.test_percentage: float = test_percentage 180 """The percentage to randomly split some data into test data.""" 181 self.validation_percentage: float = validation_percentage 182 """The percentage to randomly split some training data into validation data.""" 183 184 def prepare_data(self) -> None: 185 r"""Download the original datasets if haven't.""" 186 187 failed_dataset_classes = [] 188 for task_id in range(1, self.num_tasks + 1): 189 root = self.root[task_id] 190 dataset_class = self.original_dataset_python_classes[task_id] 191 # torchvision datasets might have different APIs 192 try: 193 # collect the error and raise it at the end to avoid stopping the whole download process 194 195 if dataset_class in [ 196 ArabicHandwrittenDigits, 197 KannadaMNIST, 198 SignLanguageMNIST, 199 ]: 200 # these datasets have no automatic download function, we require users to download them manually 201 # the following code is just to check if the dataset is already downloaded 202 dataset_class(root=root, train=True, download=False) 203 dataset_class(root=root, train=False, download=False) 204 205 elif dataset_class in [ 206 Caltech101, 207 Caltech256, 208 EuroSAT, 209 SEMEION, 210 SUN397, 211 ]: 212 # dataset classes that don't have any train, val, test split 213 dataset_class(root=root, download=True) 214 215 elif dataset_class in [ 216 ArabicHandwrittenDigits, 217 CIFAR10, 218 CIFAR100, 219 CUB2002011, 220 EMNISTByClass, 221 EMNISTByMerge, 222 EMNISTBalanced, 223 EMNISTLetters, 224 EMNISTDigits, 225 FaceScrub10, 226 FaceScrub20, 227 FaceScrub50, 228 FaceScrub100, 229 FaceScrubFromHAT, 230 FashionMNIST, 231 KannadaMNIST, 232 KMNIST, 233 Linnaeus5_32, 234 Linnaeus5_64, 235 Linnaeus5_128, 236 Linnaeus5_256, 237 MNIST, 238 NotMNIST, 239 NotMNISTFromHAT, 240 SignLanguageMNIST, 241 TrafficSignsFromHAT, 242 USPS, 243 ]: 244 # dataset classes that have `train` bool argument 245 dataset_class(root=root, train=True, download=True) 246 dataset_class(root=root, train=False, download=True) 247 248 elif dataset_class in [ 249 Food101, 250 GTSRB, 251 StanfordCars, 252 SVHN, 253 ]: 254 # dataset classes that have `split` argument with 'train', 'test' 255 dataset_class( 256 root=root, 257 split="train", 258 download=True, 259 ) 260 dataset_class( 261 root=root, 262 split="test", 263 download=True, 264 ) 265 elif dataset_class in [Country211]: 266 # dataset classes that have `split` argument with 'train', 'valid', 'test' 267 dataset_class( 268 root=root, 269 split="train", 270 download=True, 271 ) 272 dataset_class( 273 root=root, 274 split="valid", 275 download=True, 276 ) 277 dataset_class( 278 root=root, 279 split="test", 280 download=True, 281 ) 282 283 elif dataset_class in [ 284 DTD, 285 FGVCAircraftVariant, 286 FGVCAircraftFamily, 287 FGVCAircraftManufacturer, 288 Flowers102, 289 PCAM, 290 RenderedSST2, 291 ]: 292 # dataset classes that have `split` argument with 'train', 'val', 'test' 293 dataset_class( 294 root=root, 295 split="train", 296 download=True, 297 ) 298 dataset_class( 299 root=root, 300 split="val", 301 download=True, 302 ) 303 dataset_class( 304 root=root, 305 split="test", 306 download=True, 307 ) 308 elif dataset_class in [OxfordIIITPet2, OxfordIIITPet37]: 309 # dataset classes that have `split` 
argument with 'trainval', 'test' 310 dataset_class( 311 root=root, 312 split="trainval", 313 download=True, 314 ) 315 dataset_class( 316 root=root, 317 split="test", 318 download=True, 319 ) 320 elif dataset_class == CelebA: 321 # special case 322 dataset_class( 323 root=root, 324 split="train", 325 target_type="identity", 326 download=True, 327 ) 328 dataset_class( 329 root=root, 330 split="valid", 331 target_type="identity", 332 download=True, 333 ) 334 dataset_class( 335 root=root, 336 split="test", 337 target_type="identity", 338 download=True, 339 ) 340 elif dataset_class == FER2013: 341 # special case 342 dataset_class( 343 root=root, 344 split="train", 345 ) 346 dataset_class( 347 root=root, 348 split="test", 349 ) 350 elif dataset_class == TinyImageNet: 351 # special case 352 dataset_class(root=root) 353 354 except RuntimeError: 355 failed_dataset_classes.append(dataset_class) # save for later prompt 356 else: 357 pylogger.debug( 358 "The original %s dataset for task %s has been downloaded to %s.", 359 dataset_class, 360 task_id, 361 root, 362 ) 363 364 if failed_dataset_classes: 365 raise RuntimeError( 366 f"The following datasets failed to download: {failed_dataset_classes}. Please try downloading them again or manually." 367 ) 368 369 def train_and_val_dataset(self, task_id: int) -> tuple[Dataset, Dataset]: 370 r"""Get the training and validation dataset of task `task_id`. 371 372 **Args:** 373 - **task_id** (`int`): the task ID to get the training and validation dataset for. 374 375 **Returns:** 376 - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `task_id`. 377 """ 378 original_dataset_python_class_t = self.original_dataset_python_classes[task_id] 379 380 # torchvision datasets might have different APIs 381 if original_dataset_python_class_t in [ 382 ArabicHandwrittenDigits, 383 CIFAR10, 384 CIFAR100, 385 CUB2002011, 386 EMNISTByClass, 387 EMNISTByMerge, 388 EMNISTBalanced, 389 EMNISTLetters, 390 EMNISTDigits, 391 FaceScrub10, 392 FaceScrub20, 393 FaceScrub50, 394 FaceScrub100, 395 FaceScrubFromHAT, 396 FashionMNIST, 397 KannadaMNIST, 398 KMNIST, 399 Linnaeus5_32, 400 Linnaeus5_64, 401 Linnaeus5_128, 402 Linnaeus5_256, 403 MNIST, 404 NotMNIST, 405 NotMNISTFromHAT, 406 SignLanguageMNIST, 407 TrafficSignsFromHAT, 408 USPS, 409 ]: 410 # dataset classes that have `train` bool argument 411 dataset_train_and_val = original_dataset_python_class_t( 412 root=self.root[task_id], 413 train=True, 414 transform=self.train_and_val_transforms(task_id), 415 target_transform=self.target_transform(task_id), 416 ) 417 418 return random_split( 419 dataset_train_and_val, 420 lengths=[1 - self.validation_percentage, self.validation_percentage], 421 generator=torch.Generator().manual_seed( 422 42 423 ), # this must be set fixed to make sure the datasets across experiments are the same. 
Don't handle it to global seed as it might vary across experiments 424 ) 425 elif original_dataset_python_class_t in [ 426 Caltech101, 427 Caltech256, 428 EuroSAT, 429 SEMEION, 430 SUN397, 431 ]: 432 # dataset classes that don't have train and test splt 433 dataset_all = original_dataset_python_class_t( 434 root=self.root[task_id], 435 transform=self.train_and_val_transforms(task_id), 436 target_transform=self.target_transform(task_id), 437 ) 438 439 dataset_train_and_val, _ = random_split( 440 dataset_all, 441 lengths=[ 442 1 - self.test_percentage, 443 self.test_percentage, 444 ], 445 generator=torch.Generator().manual_seed( 446 42 447 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 448 ) 449 450 return random_split( 451 dataset_train_and_val, 452 lengths=[1 - self.validation_percentage, self.validation_percentage], 453 generator=torch.Generator().manual_seed( 454 42 455 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 456 ) 457 elif original_dataset_python_class_t in [Country211]: 458 # dataset classes that have `split` argument with 'train', 'valid', 'test' 459 dataset_train = original_dataset_python_class_t( 460 root=self.root[task_id], 461 split="train", 462 transform=self.train_and_val_transforms(task_id), 463 target_transform=self.target_transform(task_id), 464 ) 465 dataset_val = original_dataset_python_class_t( 466 root=self.root[task_id], 467 split="valid", 468 transform=self.train_and_val_transforms(task_id), 469 target_transform=self.target_transform(task_id), 470 ) 471 472 return dataset_train, dataset_val 473 474 elif original_dataset_python_class_t in [ 475 DTD, 476 FGVCAircraftVariant, 477 FGVCAircraftFamily, 478 FGVCAircraftManufacturer, 479 Flowers102, 480 PCAM, 481 RenderedSST2, 482 ]: 483 # dataset classes that have `split` argument with 'train', 'val', 'test' 484 dataset_train = original_dataset_python_class_t( 485 root=self.root[task_id], 486 split="train", 487 transform=self.train_and_val_transforms(task_id), 488 target_transform=self.target_transform(task_id), 489 ) 490 491 dataset_val = original_dataset_python_class_t( 492 root=self.root[task_id], 493 split="val", 494 transform=self.train_and_val_transforms(task_id), 495 target_transform=self.target_transform(task_id), 496 ) 497 498 return dataset_train, dataset_val 499 500 elif original_dataset_python_class_t in [ 501 FER2013, 502 Food101, 503 GTSRB, 504 StanfordCars, 505 SVHN, 506 TinyImageNet, 507 ]: 508 # dataset classes that have `split` argument with 'train', 'test' 509 510 dataset_train_and_val = original_dataset_python_class_t( 511 root=self.root[task_id], 512 split="train", 513 transform=self.train_and_val_transforms(task_id), 514 target_transform=self.target_transform(task_id), 515 ) 516 517 return random_split( 518 dataset_train_and_val, 519 lengths=[1 - self.validation_percentage, self.validation_percentage], 520 generator=torch.Generator().manual_seed( 521 42 522 ), # this must be set fixed to make sure the datasets across experiments are the same. 
Don't handle it to global seed as it might vary across experiments 523 ) 524 525 elif original_dataset_python_class_t in [OxfordIIITPet2, OxfordIIITPet37]: 526 # dataset classes that have `split` argument with 'trainval', 'test' 527 528 dataset_train_and_val = original_dataset_python_class_t( 529 root=self.root[task_id], 530 split="trainval", 531 transform=self.train_and_val_transforms(task_id), 532 target_transform=self.target_transform(task_id), 533 ) 534 535 return random_split( 536 dataset_train_and_val, 537 lengths=[1 - self.validation_percentage, self.validation_percentage], 538 generator=torch.Generator().manual_seed( 539 42 540 ), # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments 541 ) 542 543 elif original_dataset_python_class_t in [CelebA]: 544 # special case 545 dataset_train = original_dataset_python_class_t( 546 root=self.root[task_id], 547 split="train", 548 target_type="identity", 549 transform=self.train_and_val_transforms(task_id), 550 target_transform=self.target_transform(task_id), 551 ) 552 553 dataset_val = original_dataset_python_class_t( 554 root=self.root[task_id], 555 split="valid", 556 target_type="identity", 557 transform=self.train_and_val_transforms(task_id), 558 target_transform=self.target_transform(task_id), 559 ) 560 561 return dataset_train, dataset_val 562 563 def test_dataset(self, task_id: int) -> Dataset: 564 """Get the test dataset of task `task_id`. 565 566 **Args:** 567 - **task_id** (`int`): the task ID to get the test dataset for. 568 569 **Returns:** 570 - **test_dataset** (`Dataset`): the test dataset of task `task_id`. 571 """ 572 original_dataset_python_class_t = self.original_dataset_python_classes[task_id] 573 574 if original_dataset_python_class_t in [ 575 ArabicHandwrittenDigits, 576 CIFAR10, 577 CIFAR100, 578 CUB2002011, 579 EMNISTByClass, 580 EMNISTByMerge, 581 EMNISTBalanced, 582 EMNISTLetters, 583 EMNISTDigits, 584 FaceScrub10, 585 FaceScrub20, 586 FaceScrub50, 587 FaceScrub100, 588 FaceScrubFromHAT, 589 FashionMNIST, 590 KannadaMNIST, 591 KMNIST, 592 Linnaeus5_32, 593 Linnaeus5_64, 594 Linnaeus5_128, 595 Linnaeus5_256, 596 MNIST, 597 NotMNIST, 598 NotMNISTFromHAT, 599 SignLanguageMNIST, 600 TrafficSignsFromHAT, 601 USPS, 602 ]: 603 # dataset classes that have `train` bool argument 604 dataset_test = original_dataset_python_class_t( 605 root=self.root[task_id], 606 train=False, 607 transform=self.test_transforms(task_id), 608 target_transform=self.target_transform(task_id), 609 ) 610 611 return dataset_test 612 613 elif original_dataset_python_class_t in [ 614 Country211, 615 DTD, 616 FER2013, 617 FGVCAircraftVariant, 618 FGVCAircraftFamily, 619 FGVCAircraftManufacturer, 620 Flowers102, 621 Food101, 622 GTSRB, 623 OxfordIIITPet2, 624 OxfordIIITPet37, 625 PCAM, 626 RenderedSST2, 627 StanfordCars, 628 SVHN, 629 ]: 630 # dataset classes that have `split` argument with 'test' 631 632 dataset_test = original_dataset_python_class_t( 633 root=self.root[task_id], 634 split="test", 635 transform=self.test_transforms(task_id), 636 target_transform=self.target_transform(task_id), 637 ) 638 639 return dataset_test 640 641 elif original_dataset_python_class_t in [ 642 Caltech101, 643 Caltech256, 644 EuroSAT, 645 SEMEION, 646 SUN397, 647 ]: 648 # dataset classes that don't have train and test splt 649 650 dataset_all = original_dataset_python_class_t( 651 root=self.root[task_id], 652 transform=self.train_and_val_transforms(task_id), 653 
target_transform=self.target_transform(task_id), 654 ) 655 656 _, dataset_test = random_split( 657 dataset_all, 658 lengths=[1 - self.test_percentage, self.test_percentage], 659 generator=torch.Generator().manual_seed(42), 660 ) 661 662 return dataset_test 663 664 elif original_dataset_python_class_t in [CelebA]: 665 # special case 666 dataset_test = original_dataset_python_class_t( 667 root=self.root[task_id], 668 split="test", 669 target_type="identity", 670 transform=self.test_transforms(task_id), 671 target_transform=self.target_transform(task_id), 672 ) 673 674 return dataset_test 675 676 elif original_dataset_python_class_t in [TinyImageNet]: 677 # special case 678 dataset_test = original_dataset_python_class_t( 679 root=self.root[task_id], 680 split="val", 681 transform=self.test_transforms(task_id), 682 target_transform=self.target_transform(task_id), 683 ) 684 685 return dataset_test
class Combined(clarena.mtl_datasets.base.MTLCombinedDataset):
Combined MTL dataset from available datasets.
Combined( datasets: dict[int, str], root: str | dict[int, str], validation_percentage: float, test_percentage: float, sampling_strategy: str = 'mixed', batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)
Args:

- **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
- **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `dict[int, str]`, it should be a dict of task IDs and their corresponding root directories.
- **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data (only if a validation set is not provided in the dataset).
- **test_percentage** (`float`): the percentage to randomly split some of the entire data into test data (only if a test set is not provided in the dataset).
- **sampling_strategy** (`str`): the sampling strategy that constructs training batches from each task's dataset; one of:
    - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
- **batch_size** (`int`): the batch size in the train, val and test dataloaders.
- **num_workers** (`int`): the number of workers for dataloaders.
- **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
- **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, the same number of channels is repeated for all tasks. If it is `None`, no repeat is applied.
- **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
- **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
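For orientation, here is a minimal construction sketch. The task IDs, class-path strings, and root directories are hypothetical, and the exact class-path resolution is inherited from `MTLCombinedDataset`; treat this as an illustration rather than a canonical configuration.

from clarena.mtl_datasets.combined import Combined

# Hypothetical three-task setup; class paths and roots are illustrative only.
mtl_dataset = Combined(
    datasets={
        1: "torchvision.datasets.MNIST",
        2: "torchvision.datasets.CIFAR10",
        3: "clarena.stl_datasets.raw.NotMNIST",
    },
    root={1: "data/mnist", 2: "data/cifar10", 3: "data/notmnist"},
    validation_percentage=0.1,
    test_percentage=0.2,  # only used for datasets without an official test split
    batch_size=64,
    num_workers=4,
    repeat_channels={1: 3, 2: None, 3: 3},  # lift grayscale tasks to 3 channels
    resize=(32, 32),  # one size for all tasks
)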
AVAILABLE_DATASETS: list[torchvision.datasets.vision.VisionDataset] =
[<class 'clarena.stl_datasets.raw.ahdd.ArabicHandwrittenDigits'>, <class 'torchvision.datasets.cifar.CIFAR10'>, <class 'torchvision.datasets.cifar.CIFAR100'>, <class 'clarena.stl_datasets.raw.cub2002011.CUB2002011'>, <class 'torchvision.datasets.caltech.Caltech101'>, <class 'torchvision.datasets.caltech.Caltech256'>, <class 'torchvision.datasets.celeba.CelebA'>, <class 'torchvision.datasets.country211.Country211'>, <class 'torchvision.datasets.dtd.DTD'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTBalanced'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTByClass'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTByMerge'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTDigits'>, <class 'clarena.stl_datasets.raw.emnist.EMNISTLetters'>, <class 'torchvision.datasets.eurosat.EuroSAT'>, <class 'torchvision.datasets.fer2013.FER2013'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftFamily'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftManufacturer'>, <class 'clarena.stl_datasets.raw.fgvc_aircraft.FGVCAircraftVariant'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub10'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub100'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrubFromHAT'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub20'>, <class 'clarena.stl_datasets.raw.facescrub.FaceScrub50'>, <class 'torchvision.datasets.mnist.FashionMNIST'>, <class 'torchvision.datasets.flowers102.Flowers102'>, <class 'torchvision.datasets.food101.Food101'>, <class 'torchvision.datasets.gtsrb.GTSRB'>, <class 'torchvision.datasets.mnist.KMNIST'>, <class 'clarena.stl_datasets.raw.kannada_mnist.KannadaMNIST'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_128'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_256'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_32'>, <class 'clarena.stl_datasets.raw.linnaeus5.Linnaeus5_64'>, <class 'torchvision.datasets.mnist.MNIST'>, <class 'clarena.stl_datasets.raw.notmnist.NotMNIST'>, <class 'clarena.stl_datasets.raw.notmnist.NotMNISTFromHAT'>, <class 'clarena.stl_datasets.raw.oxford_iiit_pet.OxfordIIITPet2'>, <class 'clarena.stl_datasets.raw.oxford_iiit_pet.OxfordIIITPet37'>, <class 'torchvision.datasets.pcam.PCAM'>, <class 'torchvision.datasets.rendered_sst2.RenderedSST2'>, <class 'torchvision.datasets.semeion.SEMEION'>, <class 'torchvision.datasets.sun397.SUN397'>, <class 'torchvision.datasets.svhn.SVHN'>, <class 'clarena.stl_datasets.raw.sign_language_mnist.SignLanguageMNIST'>, <class 'torchvision.datasets.stanford_cars.StanfordCars'>, <class 'clarena.stl_datasets.raw.traffic_signs.TrafficSignsFromHAT'>, <class 'tinyimagenet.TinyImageNet'>, <class 'torchvision.datasets.usps.USPS'>]
The list of available datasets.
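As an illustration of how this list can be used, the helper below (hypothetical, not part of the library) checks a task configuration against `AVAILABLE_DATASETS` before construction:

from clarena.mtl_datasets.combined import Combined

def check_available(datasets: dict[int, str]) -> None:
    # Compare the class name at the end of each configured path with the
    # names of the classes registered in AVAILABLE_DATASETS.
    available_names = {cls.__name__ for cls in Combined.AVAILABLE_DATASETS}
    for task_id, class_path in datasets.items():
        name = class_path.rsplit(".", 1)[-1]
        if name not in available_names:
            raise ValueError(f"Task {task_id}: {name} is not an available dataset.")

check_available({1: "torchvision.datasets.MNIST", 2: "torchvision.datasets.CIFAR10"})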
test_percentage: float
The percentage to randomly split some data into test data.

validation_percentage: float
The percentage to randomly split some training data into validation data.
def prepare_data(self) -> None:
Download the original datasets if they haven't been downloaded yet.
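The method tries each task's download in turn and collects failures instead of aborting on the first one, raising a single error at the end. A standalone sketch of that collect-then-raise pattern, with stand-in download callables rather than the real dataset classes:

import logging
from typing import Callable

pylogger = logging.getLogger(__name__)

def download_all(download_fns: dict[int, Callable[[], None]]) -> None:
    # Attempt every task's download; remember the ones that fail so a
    # single error can report them all at the end.
    failed: list[int] = []
    for task_id, download in download_fns.items():
        try:
            download()
        except RuntimeError:
            failed.append(task_id)
        else:
            pylogger.debug("Task %s downloaded.", task_id)
    if failed:
        raise RuntimeError(f"Downloads failed for tasks: {failed}.")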
def train_and_val_dataset(self, task_id: int) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
Get the training and validation dataset of task `task_id`.
Args:

- **task_id** (`int`): the task ID to get the training and validation dataset for.

Returns:

- **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `task_id`.
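When a dataset ships no official validation split, the method carves one out with `random_split` under a generator fixed to seed 42, so the partition is identical across experiments regardless of the global seed. A minimal self-contained sketch of that idiom (the toy dataset is illustrative):

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(100))
validation_percentage = 0.1

# Fractional lengths need torch >= 1.13; the fixed seed keeps the
# train/val partition stable across runs.
train_set, val_set = random_split(
    dataset,
    lengths=[1 - validation_percentage, validation_percentage],
    generator=torch.Generator().manual_seed(42),
)
print(len(train_set), len(val_set))  # 90 10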
def test_dataset(self, task_id: int) -> torch.utils.data.dataset.Dataset:
Get the test dataset of task `task_id`.
Args:

- **task_id** (`int`): the task ID to get the test dataset for.

Returns:

- **test_dataset** (`Dataset`): the test dataset of task `task_id`.
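For datasets with no official test split, `test_dataset` repeats the same seeded `random_split` used in `train_and_val_dataset` and keeps the complementary part, so the two subsets are guaranteed disjoint. A small sketch verifying that property on a toy dataset:

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(100))
test_percentage = 0.2

def split(ds):
    # The same seed on both calls reproduces the identical partition.
    return random_split(
        ds,
        lengths=[1 - test_percentage, test_percentage],
        generator=torch.Generator().manual_seed(42),
    )

train_and_val, _ = split(dataset)  # the part kept by train_and_val_dataset
_, test_set = split(dataset)       # the part kept by test_dataset
assert set(train_and_val.indices).isdisjoint(test_set.indices)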