clarena.cl_datasets.permuted_mnist
The submodule in `cl_datasets`
for the Permuted MNIST dataset.
"""
The submodule in `cl_datasets` for the Permuted MNIST dataset.
"""

__all__ = ["PermutedMNIST"]

import logging
from typing import Callable

import torch
from torch.utils.data import Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms

from clarena.cl_datasets import CLPermutedDataset

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class PermutedMNIST(CLPermutedDataset):
    """Permuted MNIST dataset.

    Supplies the MNIST-specific constants and the raw train/validation/test datasets, and forwards all
    permutation and dataloader settings to `CLPermutedDataset`.
    """

    num_class: int = 10
    """The number of classes in MNIST."""

    img_size: torch.Size = torch.Size([1, 28, 28])
    """The size of MNIST images."""

    mean_original: tuple[float] = (0.1307,)
    """The mean values for normalisation."""

    std_original: tuple[float] = (0.3081,)
    """The standard deviation values for normalisation."""

    def __init__(
        self,
        root: str,
        num_tasks: int,
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 10,
        custom_transforms: Callable | transforms.Compose | None = None,
        custom_target_transforms: Callable | transforms.Compose | None = None,
        permutation_mode: str = "first_channel_only",
        permutation_seeds: list[int] | None = None,
    ) -> None:
        """Initialise the Permuted MNIST dataset.

        Every argument is passed unchanged to the `CLPermutedDataset` initialiser.

        **Args:**
        - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
        - **validation_percentage** (`float`): the proportion of the training data that is randomly split off as validation data. It is used as a fraction (see `train_and_val_dataset`), so it should lie between 0 and 1.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
            1. 'all': permute all pixels.
            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
            3. 'first_channel_only': permute only the first channel.
        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`.
        """
        super().__init__(
            root=root,
            num_tasks=num_tasks,
            validation_percentage=validation_percentage,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            custom_target_transforms=custom_target_transforms,
            permutation_mode=permutation_mode,
            permutation_seeds=permutation_seeds,
        )

    def prepare_data(self) -> None:
        """Download the original MNIST dataset to `self.root` if it has not been downloaded yet."""
        # The MNIST objects are created only for their download side effect and then discarded.
        for is_train_split in (True, False):
            MNIST(root=self.root, train=is_train_split, download=True)

        pylogger.debug(
            "The original MNIST dataset has been downloaded to %s.", self.root
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        """Get the training and validation dataset of task `self.task_id`.

        The MNIST training set is randomly split into two parts, with `self.validation_percentage`
        giving the fraction that goes to validation.

        **Returns:**
        - The train and validation dataset of task `self.task_id`, as a 2-tuple in that order.
        """
        full_training_set = MNIST(
            root=self.root,
            train=True,
            transform=self.train_and_val_transforms(to_tensor=True),
            target_transform=self.target_transforms(),
            download=False,
        )
        val_fraction = self.validation_percentage
        # `random_split` returns a list; unpack it so the method really returns the tuple it declares.
        train_subset, val_subset = random_split(
            full_training_set,
            lengths=[1 - val_fraction, val_fraction],
        )
        return train_subset, val_subset

    def test_dataset(self) -> Dataset:
        """Get the test dataset of task `self.task_id`.

        **Returns:**
        - The test dataset of task `self.task_id`.
        """
        image_transform = self.test_transforms(to_tensor=True)
        label_transform = self.target_transforms()
        return MNIST(
            root=self.root,
            train=False,
            transform=image_transform,
            target_transform=label_transform,
            download=False,
        )
class
PermutedMNIST(clarena.cl_datasets.base.CLPermutedDataset):
class PermutedMNIST(CLPermutedDataset):
    """Permuted MNIST dataset.

    Supplies the MNIST-specific constants and the raw train/validation/test datasets, and forwards all
    permutation and dataloader settings to `CLPermutedDataset`.
    """

    num_class: int = 10
    """The number of classes in MNIST."""

    img_size: torch.Size = torch.Size([1, 28, 28])
    """The size of MNIST images."""

    mean_original: tuple[float] = (0.1307,)
    """The mean values for normalisation."""

    std_original: tuple[float] = (0.3081,)
    """The standard deviation values for normalisation."""

    def __init__(
        self,
        root: str,
        num_tasks: int,
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 10,
        custom_transforms: Callable | transforms.Compose | None = None,
        custom_target_transforms: Callable | transforms.Compose | None = None,
        permutation_mode: str = "first_channel_only",
        permutation_seeds: list[int] | None = None,
    ) -> None:
        """Initialise the Permuted MNIST dataset.

        Every argument is passed unchanged to the `CLPermutedDataset` initialiser.

        **Args:**
        - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
        - **validation_percentage** (`float`): the proportion of the training data that is randomly split off as validation data. It is used as a fraction (see `train_and_val_dataset`), so it should lie between 0 and 1.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
            1. 'all': permute all pixels.
            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
            3. 'first_channel_only': permute only the first channel.
        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`.
        """
        super().__init__(
            root=root,
            num_tasks=num_tasks,
            validation_percentage=validation_percentage,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            custom_target_transforms=custom_target_transforms,
            permutation_mode=permutation_mode,
            permutation_seeds=permutation_seeds,
        )

    def prepare_data(self) -> None:
        """Download the original MNIST dataset to `self.root` if it has not been downloaded yet."""
        # The MNIST objects are created only for their download side effect and then discarded.
        for is_train_split in (True, False):
            MNIST(root=self.root, train=is_train_split, download=True)

        pylogger.debug(
            "The original MNIST dataset has been downloaded to %s.", self.root
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        """Get the training and validation dataset of task `self.task_id`.

        The MNIST training set is randomly split into two parts, with `self.validation_percentage`
        giving the fraction that goes to validation.

        **Returns:**
        - The train and validation dataset of task `self.task_id`, as a 2-tuple in that order.
        """
        full_training_set = MNIST(
            root=self.root,
            train=True,
            transform=self.train_and_val_transforms(to_tensor=True),
            target_transform=self.target_transforms(),
            download=False,
        )
        val_fraction = self.validation_percentage
        # `random_split` returns a list; unpack it so the method really returns the tuple it declares.
        train_subset, val_subset = random_split(
            full_training_set,
            lengths=[1 - val_fraction, val_fraction],
        )
        return train_subset, val_subset

    def test_dataset(self) -> Dataset:
        """Get the test dataset of task `self.task_id`.

        **Returns:**
        - The test dataset of task `self.task_id`.
        """
        image_transform = self.test_transforms(to_tensor=True)
        label_transform = self.target_transforms()
        return MNIST(
            root=self.root,
            train=False,
            transform=image_transform,
            target_transform=label_transform,
            download=False,
        )
Permuted MNIST dataset.
PermutedMNIST( root: str, num_tasks: int, validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, permutation_mode: str = 'first_channel_only', permutation_seeds: list[int] | None = None)
def __init__(
    self,
    root: str,
    num_tasks: int,
    validation_percentage: float,
    batch_size: int = 1,
    num_workers: int = 10,
    custom_transforms: Callable | transforms.Compose | None = None,
    custom_target_transforms: Callable | transforms.Compose | None = None,
    permutation_mode: str = "first_channel_only",
    permutation_seeds: list[int] | None = None,
) -> None:
    """Initialise the Permuted MNIST dataset.

    Every argument is passed unchanged, by keyword, to the parent class initialiser.

    **Args:**
    - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
    - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
    - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
    - **batch_size** (`int`): the batch size in train, val, test dataloader.
    - **num_workers** (`int`): the number of workers for dataloaders.
    - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
    - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
    - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
        1. 'all': permute all pixels.
        2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
        3. 'first_channel_only': permute only the first channel.
    - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`.
    """
    super().__init__(
        root=root,
        num_tasks=num_tasks,
        validation_percentage=validation_percentage,
        batch_size=batch_size,
        num_workers=num_workers,
        custom_transforms=custom_transforms,
        custom_target_transforms=custom_target_transforms,
        permutation_mode=permutation_mode,
        permutation_seeds=permutation_seeds,
    )
Initialise the Permuted MNIST dataset.
Args:
- root (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
- num_tasks (`int`): the maximum number of tasks supported by the CL dataset.
- validation_percentage (`float`): the proportion of the training data that is randomly split off as validation data. It is used as a fraction when splitting, so it should lie between 0 and 1.
- batch_size (`int`): the batch size in train, val, test dataloader.
- num_workers (`int`): the number of workers for dataloaders.
- custom_transforms (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalise, permute and so on are not included.
- custom_target_transforms (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
- permutation_mode (`str`): the mode of permutation, should be one of the following:
    1. 'all': permute all pixels.
    2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
    3. 'first_channel_only': permute only the first channel.
- permutation_seeds (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`.
def
prepare_data(self) -> None:
def prepare_data(self) -> None:
    """Download the original MNIST dataset to `self.root` if it has not been downloaded yet."""
    # The MNIST objects are created only for their download side effect and then discarded.
    for is_train_split in (True, False):
        MNIST(root=self.root, train=is_train_split, download=True)

    pylogger.debug(
        "The original MNIST dataset has been downloaded to %s.", self.root
    )
Download the original MNIST dataset if it has not been downloaded yet.
def
train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
    """Get the training and validation dataset of task `self.task_id`.

    The MNIST training set is randomly split into two parts, with `self.validation_percentage`
    giving the fraction that goes to validation.

    **Returns:**
    - The train and validation dataset of task `self.task_id`, as a 2-tuple in that order.
    """
    full_training_set = MNIST(
        root=self.root,
        train=True,
        transform=self.train_and_val_transforms(to_tensor=True),
        target_transform=self.target_transforms(),
        download=False,
    )
    val_fraction = self.validation_percentage
    # `random_split` returns a list; unpack it so the method really returns the tuple it declares.
    train_subset, val_subset = random_split(
        full_training_set,
        lengths=[1 - val_fraction, val_fraction],
    )
    return train_subset, val_subset
Get the training and validation dataset of task `self.task_id`.
Returns:
- The train and validation dataset of task `self.task_id`.
def
test_dataset(self) -> torch.utils.data.dataset.Dataset:
def test_dataset(self) -> Dataset:
    """Get the test dataset of task `self.task_id`.

    **Returns:**
    - The test dataset of task `self.task_id`.
    """
    image_transform = self.test_transforms(to_tensor=True)
    label_transform = self.target_transforms()
    return MNIST(
        root=self.root,
        train=False,
        transform=image_transform,
        target_transform=label_transform,
        download=False,
    )
Get the test dataset of task `self.task_id`.
Returns:
- The test dataset of task `self.task_id`.