clarena.cl_datasets.permuted_mnist
The submodule in `cl_datasets` for the Permuted MNIST dataset.
1r""" 2The submodule in `cl_datasets` for Permuted MNIST dataset. 3""" 4 5__all__ = ["PermutedMNIST"] 6 7import logging 8from typing import Callable 9 10import torch 11from torch.utils.data import Dataset, random_split 12from torchvision.datasets import MNIST 13from torchvision.transforms import transforms 14 15from clarena.cl_datasets import CLPermutedDataset 16 17# always get logger for built-in logging in each module 18pylogger = logging.getLogger(__name__) 19 20 21class PermutedMNIST(CLPermutedDataset): 22 r"""Permuted MNIST dataset. The [original MNIST dataset](http://yann.lecun.com/exdb/mnist/) is a collection of handwritten digits. It consists of 70,000 28x28 B&W images in 10 classes (correspond to 10 digits), with 7000 images per class. There are 60,000 training examples and 10,000 test examples.""" 23 24 num_classes: int = 10 25 """The number of classes in MNIST.""" 26 27 img_size: torch.Size = torch.Size([1, 28, 28]) 28 """The size of MNIST images.""" 29 30 mean_original: tuple[float] = (0.1307,) 31 """The mean values for normalisation.""" 32 33 std_original: tuple[float] = (0.3081,) 34 """The standard deviatfion values for normalisation.""" 35 36 def __init__( 37 self, 38 root: str, 39 num_tasks: int, 40 validation_percentage: float, 41 batch_size: int = 1, 42 num_workers: int = 8, 43 custom_transforms: Callable | transforms.Compose | None = None, 44 custom_target_transforms: Callable | transforms.Compose | None = None, 45 permutation_mode: str = "first_channel_only", 46 permutation_seeds: list[int] | None = None, 47 ) -> None: 48 r"""Initialise the Permuted MNIST dataset. 49 50 **Args:** 51 - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live. 52 - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset. 53 - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data. 54 - **batch_size** (`int`): The batch size in train, val, test dataloader. 55 - **num_workers** (`int`): the number of workers for dataloaders. 56 - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. 57 `ToTensor()`, normalise, permute and so on are not included. 58 - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included. 59 - **permutation_mode** (`str`): the mode of permutation, should be one of the following: 60 1. 'all': permute all pixels. 61 2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order. 62 3. 'first_channel_only': permute only the first channel. 63 - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`. 
64 """ 65 CLPermutedDataset.__init__( 66 self, 67 root=root, 68 num_tasks=num_tasks, 69 validation_percentage=validation_percentage, 70 batch_size=batch_size, 71 num_workers=num_workers, 72 custom_transforms=custom_transforms, 73 custom_target_transforms=custom_target_transforms, 74 permutation_mode=permutation_mode, 75 permutation_seeds=permutation_seeds, 76 ) 77 78 def prepare_data(self) -> None: 79 r"""Download the original MNIST dataset if haven't.""" 80 # just download 81 MNIST(root=self.root, train=True, download=True) 82 MNIST(root=self.root, train=False, download=True) 83 84 pylogger.debug( 85 "The original MNIST dataset has been downloaded to %s.", self.root 86 ) 87 88 def train_and_val_dataset(self) -> tuple[Dataset, Dataset]: 89 """Get the training and validation dataset of task `self.task_id`. 90 91 **Returns:** 92 - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`. 93 """ 94 dataset_train_and_val = MNIST( 95 root=self.root, 96 train=True, 97 transform=self.train_and_val_transforms(to_tensor=True), 98 download=False, 99 ) 100 return random_split( 101 dataset_train_and_val, 102 lengths=[1 - self.validation_percentage, self.validation_percentage], 103 ) 104 105 def test_dataset(self) -> Dataset: 106 r"""Get the test dataset of task `self.task_id`. 107 108 **Returns:** 109 - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`. 110 """ 111 112 return MNIST( 113 root=self.root, 114 train=False, 115 transform=self.test_transforms(to_tensor=True), 116 download=False, 117 )
`class PermutedMNIST(clarena.cl_datasets.base.CLPermutedDataset):`
Permuted MNIST dataset. The [original MNIST dataset](http://yann.lecun.com/exdb/mnist/) is a collection of handwritten digits. It consists of 70,000 28x28 B&W images in 10 classes (corresponding to the 10 digits), with 7,000 images per class. There are 60,000 training examples and 10,000 test examples.
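A minimal usage sketch follows. It is not taken from the clarena documentation; the root path and argument values are assumptions chosen only for illustration.

```python
# Illustrative sketch only: construct the Permuted MNIST CL dataset directly.
# The root path and argument values below are assumptions for this example.
from clarena.cl_datasets.permuted_mnist import PermutedMNIST

cl_dataset = PermutedMNIST(
    root="data/",                # MNIST is downloaded to data/MNIST/raw/
    num_tasks=10,                # ten differently permuted variants of MNIST
    validation_percentage=0.1,   # hold out 10% of the training data for validation
    batch_size=64,
    num_workers=4,
    permutation_mode="first_channel_only",
    permutation_seeds=None,      # defaults to seeds 1..num_tasks
)
```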
```python
PermutedMNIST(
    root: str,
    num_tasks: int,
    validation_percentage: float,
    batch_size: int = 1,
    num_workers: int = 8,
    custom_transforms: Callable | transforms.Compose | None = None,
    custom_target_transforms: Callable | transforms.Compose | None = None,
    permutation_mode: str = "first_channel_only",
    permutation_seeds: list[int] | None = None,
)
```
Initialise the Permuted MNIST dataset.
**Args:**
- **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
- **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
- **validation_percentage** (`float`): the percentage of the training data randomly split off as validation data.
- **batch_size** (`int`): the batch size of the train, validation and test dataloaders.
- **num_workers** (`int`): the number of workers for the dataloaders.
- **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms applied to the TRAIN dataset only. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalisation, permutation and so on are not included.
- **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms applied to the dataset labels. Can be a single transform, composed transforms, or no transform. CL class mapping is not included.
- **permutation_mode** (`str`): the mode of permutation; should be one of the following (a sketch of how each mode could act on an image follows this argument list):
    1. 'all': permute all pixels.
    2. 'by_channel': permute channel by channel separately, applying the same permutation order to every channel.
    3. 'first_channel_only': permute only the first channel.
- **permutation_seeds** (`list[int]` or `None`): the seeds of the permutation operations used to construct the tasks. Make sure it has the same number of seeds as `num_tasks`. Default is `None`, which creates a list of seeds from 1 to `num_tasks`.
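The sketch below is not clarena's implementation; it only illustrates what a fixed, seeded pixel permutation could look like under each `permutation_mode` for a `(C, H, W)` image tensor. The helper name `permute_image` is hypothetical.

```python
# Hypothetical illustration of the three permutation modes; not clarena's actual code.
import torch


def permute_image(img: torch.Tensor, seed: int, mode: str) -> torch.Tensor:
    """Apply a fixed, seed-determined pixel permutation to a (C, H, W) image."""
    c, h, w = img.shape
    gen = torch.Generator().manual_seed(seed)
    if mode == "all":
        # permute every pixel across all channels at once
        perm = torch.randperm(c * h * w, generator=gen)
        return img.flatten()[perm].reshape(c, h, w)
    perm = torch.randperm(h * w, generator=gen)
    if mode == "by_channel":
        # permute each channel separately, using the same order for every channel
        return img.reshape(c, h * w)[:, perm].reshape(c, h, w)
    if mode == "first_channel_only":
        # permute only the pixels of the first channel, leave the rest untouched
        out = img.clone()
        out[0] = img[0].flatten()[perm].reshape(h, w)
        return out
    raise ValueError(f"Unknown permutation_mode: {mode}")
```

In this sketch the three modes coincide for single-channel images such as MNIST; the distinction only matters for multi-channel inputs.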
`def prepare_data(self) -> None:`
Download the original MNIST dataset if it hasn't been downloaded yet.
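After torchvision's `MNIST(..., download=True)` calls, the raw files referenced in the `root` argument should sit roughly in the layout below (auxiliary files may vary with the torchvision version):

```
<root>/
└── MNIST/
    └── raw/
        ├── train-images-idx3-ubyte
        ├── train-labels-idx1-ubyte
        ├── t10k-images-idx3-ubyte
        └── t10k-labels-idx1-ubyte
```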
`def train_and_val_dataset(self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:`
Get the training and validation dataset of task `self.task_id`.

**Returns:**
- **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.
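The split performed by this method relies on `torch.utils.data.random_split` accepting fractional lengths that sum to 1. The snippet below is a standalone illustration with a dummy dataset, not the module's own code, showing how `validation_percentage = 0.1` divides the 60,000 MNIST training examples:

```python
# Standalone illustration of the fractional random_split used above; not clarena code.
import torch
from torch.utils.data import TensorDataset, random_split

dummy = TensorDataset(
    torch.zeros(60_000, 1, 28, 28), torch.zeros(60_000, dtype=torch.long)
)
train_subset, val_subset = random_split(dummy, lengths=[0.9, 0.1])
print(len(train_subset), len(val_subset))  # 54000 6000
```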
`def test_dataset(self) -> torch.utils.data.dataset.Dataset:`
Get the test dataset of task `self.task_id`.

**Returns:**
- **test_dataset** (`Dataset`): the test dataset of task `self.task_id`.