clarena.cl_datasets.split_cifar100

The submodule in `cl_datasets` for the Split CIFAR100 dataset.
1""" 2The submodule in `cl_datasets` for Split CIFAR100 dataset. 3""" 4 5__all__ = ["SplitCIFAR100"] 6 7import logging 8from typing import Callable 9 10from torch.utils.data import Dataset, random_split 11from torchvision.datasets import CIFAR100 12from torchvision.transforms import transforms 13 14from clarena.cl_datasets import CLSplitDataset 15 16# always get logger for built-in logging in each module 17pylogger = logging.getLogger(__name__) 18 19 20class SplitCIFAR100(CLSplitDataset): 21 """Split CIFAR100 dataset.""" 22 23 num_classes: int = 100 24 """The number of classes in CIFAR100 dataset.""" 25 26 mean_original: tuple[float] = (0.1307,) 27 """The mean values for normalisation.""" 28 29 std_original: tuple[float] = (0.3081,) 30 """The standard deviation values for normalisation.""" 31 32 def __init__( 33 self, 34 root: str, 35 num_tasks: int, 36 class_split: list[list[int]], 37 validation_percentage: float, 38 batch_size: int = 1, 39 num_workers: int = 10, 40 custom_transforms: Callable | transforms.Compose | None = None, 41 custom_target_transforms: Callable | transforms.Compose | None = None, 42 ) -> None: 43 """Initialise the Permuted MNIST dataset. 44 45 **Args:** 46 - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live. 47 - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. 48 - **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task. 49 - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data. 50 - **batch_size** (`int`): The batch size in train, val, test dataloader. 51 - **num_workers** (`int`): the number of workers for dataloaders. 52 - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. 53 `ToTensor()`, normalise, permute and so on are not included. 54 - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included. 55 - **permutation_mode** (`str`): the mode of permutation, should be one of the following: 56 1. 'all': permute all pixels. 57 2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order. 58 3. 'first_channel_only': permute only the first channel. 59 - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`. 
60 """ 61 super().__init__( 62 root=root, 63 num_tasks=num_tasks, 64 class_split=class_split, 65 validation_percentage=validation_percentage, 66 batch_size=batch_size, 67 num_workers=num_workers, 68 custom_transforms=custom_transforms, 69 custom_target_transforms=custom_target_transforms, 70 ) 71 72 def prepare_data(self) -> None: 73 """Download the original MNIST dataset if haven't.""" 74 # just download 75 CIFAR100(root=self.root, train=True, download=True) 76 CIFAR100(root=self.root, train=False, download=True) 77 78 pylogger.debug( 79 "The original CIFAR100 dataset has been downloaded to %s.", self.root 80 ) 81 82 def train_and_val_dataset(self) -> tuple[Dataset, Dataset]: 83 """Get the training and validation dataset of task `self.task_id`. 84 85 **Returns:** 86 - The train and validation dataset of task `self.task_id`.""" 87 dataset_train_and_val = self.get_class_subset( 88 CIFAR100( 89 root=self.root, 90 train=True, 91 transform=self.train_and_val_transforms(to_tensor=True), 92 target_transform=self.target_transforms(), 93 download=False, 94 ), 95 classes=self.class_split[self.task_id - 1], 96 ) 97 return random_split( 98 dataset_train_and_val, 99 lengths=[1 - self.validation_percentage, self.validation_percentage], 100 ) 101 102 def test_dataset(self) -> Dataset: 103 """Get the test dataset of task `self.task_id`. 104 105 **Returns:** 106 - The test dataset of task `self.task_id`. 107 """ 108 return self.get_class_subset( 109 CIFAR100( 110 root=self.root, 111 train=False, 112 transform=self.test_transforms(to_tensor=True), 113 target_transform=self.target_transforms(), 114 download=False, 115 ), 116 classes=self.class_split[self.task_id - 1], 117 )
`class SplitCIFAR100(clarena.cl_datasets.base.CLSplitDataset):`
Split CIFAR100 dataset.
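Split CIFAR100 partitions the 100 CIFAR100 classes into disjoint subsets, one per task, as specified by the `class_split` argument. A minimal sketch of building an even split (the `even_class_split` helper is illustrative, not part of the clarena API):

```python
def even_class_split(num_classes: int, num_tasks: int) -> list[list[int]]:
    # illustrative helper: divide labels 0..num_classes-1 evenly across tasks
    per_task = num_classes // num_tasks
    return [list(range(t * per_task, (t + 1) * per_task)) for t in range(num_tasks)]

print(even_class_split(100, 10)[0])  # the first task covers classes 0-9
```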
```python
SplitCIFAR100(
    root: str,
    num_tasks: int,
    class_split: list[list[int]],
    validation_percentage: float,
    batch_size: int = 1,
    num_workers: int = 10,
    custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None,
    custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None,
)
```
Initialise the Split CIFAR100 dataset.

**Args:**
- **root** (`str`): the root directory where the original CIFAR100 data 'cifar-100-python/' lives.
- **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
- **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
- **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
- **batch_size** (`int`): the batch size in train, val, test dataloader.
- **num_workers** (`int`): the number of workers for dataloaders.
- **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalisation and so on are not included.
- **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms, or no transform. CL class mapping is not included.
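A usage sketch, assuming a local `./data` directory and illustrative hyperparameter values (the 10-task even split below is one possible choice, not a default):

```python
from clarena.cl_datasets.split_cifar100 import SplitCIFAR100

# 10 tasks of 10 classes each; all values here are illustrative
cifar100_cl = SplitCIFAR100(
    root="./data",
    num_tasks=10,
    class_split=[list(range(t * 10, (t + 1) * 10)) for t in range(10)],
    validation_percentage=0.1,
    batch_size=64,
    num_workers=4,
)
```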
`def prepare_data(self) -> None:`
Download the original CIFAR100 dataset if it hasn't been downloaded yet.
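Since the method only triggers torchvision's download logic, repeated calls are cheap: torchvision verifies the existing archive and skips the download. A standalone sketch, assuming a local `./data` directory:

```python
from torchvision.datasets import CIFAR100

# the first call downloads 'cifar-100-python'; later calls print
# "Files already downloaded and verified" and return immediately
CIFAR100(root="./data", train=True, download=True)
CIFAR100(root="./data", train=False, download=True)
```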
`def train_and_val_dataset(self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:`
Get the training and validation dataset of task `self.task_id`.

**Returns:**
- The train and validation dataset of task `self.task_id`.
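The `lengths=[1 - p, p]` call in the method body relies on `random_split` accepting fractions that sum to 1 (supported since PyTorch 1.13); a minimal standalone sketch of the behaviour:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# toy dataset of 100 samples standing in for a task's training data
toy = TensorDataset(torch.arange(100))

# fractional lengths: 90% train, 10% validation
train_set, val_set = random_split(toy, lengths=[0.9, 0.1])
print(len(train_set), len(val_set))  # 90 10
```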
`def test_dataset(self) -> torch.utils.data.dataset.Dataset:`
Get the test dataset of task `self.task_id`.

**Returns:**
- The test dataset of task `self.task_id`.
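`get_class_subset` is inherited from `CLSplitDataset` and not defined in this module; a hypothetical sketch of such a filter, exploiting the `targets` attribute that torchvision's `CIFAR100` exposes:

```python
from torch.utils.data import Subset
from torchvision.datasets import CIFAR100

def class_subset_sketch(dataset: CIFAR100, classes: list[int]) -> Subset:
    # hypothetical stand-in for CLSplitDataset.get_class_subset:
    # keep only the samples whose label belongs to `classes`
    wanted = set(classes)
    indices = [i for i, label in enumerate(dataset.targets) if label in wanted]
    return Subset(dataset, indices)
```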