clarena.cl_datasets.permuted_mnist

The submodule in cl_datasets for the Permuted MNIST dataset.

  1"""
  2The submodule in `cl_datasets` for Permuted MNIST dataset.
  3"""
  4
  5__all__ = ["PermutedMNIST"]
  6
  7import logging
  8from typing import Callable
  9
 10import torch
 11from torch.utils.data import Dataset, random_split
 12from torchvision.datasets import MNIST
 13from torchvision.transforms import transforms
 14
 15from clarena.cl_datasets import CLPermutedDataset
 16
 17# always get logger for built-in logging in each module
 18pylogger = logging.getLogger(__name__)
 19
 20
 21class PermutedMNIST(CLPermutedDataset):
 22    """Permuted MNIST dataset."""
 23
 24    num_class: int = 10
 25    """The number of classes in MNIST."""
 26
 27    img_size: torch.Size = torch.Size([1, 28, 28])
 28    """The size of MNIST images."""
 29
 30    mean_original: tuple[float] = (0.1307,)
 31    """The mean values for normalisation."""
 32
 33    std_original: tuple[float] = (0.3081,)
 34    """The standard deviation values for normalisation."""
 35
 36    def __init__(
 37        self,
 38        root: str,
 39        num_tasks: int,
 40        validation_percentage: float,
 41        batch_size: int = 1,
 42        num_workers: int = 10,
 43        custom_transforms: Callable | transforms.Compose | None = None,
 44        custom_target_transforms: Callable | transforms.Compose | None = None,
 45        permutation_mode: str = "first_channel_only",
 46        permutation_seeds: list[int] | None = None,
 47    ) -> None:
 48        """Initialise the Permuted MNIST dataset.
 49
 50        **Args:**
 51        - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
 52        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
 53        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
 54        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 55        - **num_workers** (`int`): the number of workers for dataloaders.
 56        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform.
 57        `ToTensor()`, normalise, permute and so on are not included.
 58        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
 59        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
 60            1. 'all': permute all pixels.
 61            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
 62            3. 'first_channel_only': permute only the first channel.
 63        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`.
 64        """
 65        super().__init__(
 66            root=root,
 67            num_tasks=num_tasks,
 68            validation_percentage=validation_percentage,
 69            batch_size=batch_size,
 70            num_workers=num_workers,
 71            custom_transforms=custom_transforms,
 72            custom_target_transforms=custom_target_transforms,
 73            permutation_mode=permutation_mode,
 74            permutation_seeds=permutation_seeds,
 75        )
 76
 77    def prepare_data(self) -> None:
 78        """Download the original MNIST dataset if haven't."""
 79        # just download
 80        MNIST(root=self.root, train=True, download=True)
 81        MNIST(root=self.root, train=False, download=True)
 82
 83        pylogger.debug(
 84            "The original MNIST dataset has been downloaded to %s.", self.root
 85        )
 86
 87    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 88        """Get the training and validation dataset of task `self.task_id`.
 89
 90        **Returns:**
 91        - The train and validation dataset of task `self.task_id`."""
 92        dataset_train_and_val = MNIST(
 93            root=self.root,
 94            train=True,
 95            transform=self.train_and_val_transforms(to_tensor=True),
 96            target_transform=self.target_transforms(),
 97            download=False,
 98        )
 99        return random_split(
100            dataset_train_and_val,
101            lengths=[1 - self.validation_percentage, self.validation_percentage],
102        )
103
104    def test_dataset(self) -> Dataset:
105        """Get the test dataset of task `self.task_id`.
106
107        **Returns:**
108        - The test dataset of task `self.task_id`.
109        """
110        return MNIST(
111            root=self.root,
112            train=False,
113            transform=self.test_transforms(to_tensor=True),
114            target_transform=self.target_transforms(),
115            download=False,
116        )
class PermutedMNIST(clarena.cl_datasets.base.CLPermutedDataset):
 22class PermutedMNIST(CLPermutedDataset):
 23    """Permuted MNIST dataset."""
 24
 25    num_class: int = 10
 26    """The number of classes in MNIST."""
 27
 28    img_size: torch.Size = torch.Size([1, 28, 28])
 29    """The size of MNIST images."""
 30
 31    mean_original: tuple[float] = (0.1307,)
 32    """The mean values for normalisation."""
 33
 34    std_original: tuple[float] = (0.3081,)
 35    """The standard deviation values for normalisation."""
 36
 37    def __init__(
 38        self,
 39        root: str,
 40        num_tasks: int,
 41        validation_percentage: float,
 42        batch_size: int = 1,
 43        num_workers: int = 10,
 44        custom_transforms: Callable | transforms.Compose | None = None,
 45        custom_target_transforms: Callable | transforms.Compose | None = None,
 46        permutation_mode: str = "first_channel_only",
 47        permutation_seeds: list[int] | None = None,
 48    ) -> None:
 49        """Initialise the Permuted MNIST dataset.
 50
 51        **Args:**
 52        - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
 53        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
 54        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
 55        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 56        - **num_workers** (`int`): the number of workers for dataloaders.
 57        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform.
 58        `ToTensor()`, normalise, permute and so on are not included.
 59        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
 60        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
 61            1. 'all': permute all pixels.
 62            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
 63            3. 'first_channel_only': permute only the first channel.
 64        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`.
 65        """
 66        super().__init__(
 67            root=root,
 68            num_tasks=num_tasks,
 69            validation_percentage=validation_percentage,
 70            batch_size=batch_size,
 71            num_workers=num_workers,
 72            custom_transforms=custom_transforms,
 73            custom_target_transforms=custom_target_transforms,
 74            permutation_mode=permutation_mode,
 75            permutation_seeds=permutation_seeds,
 76        )
 77
 78    def prepare_data(self) -> None:
 79        """Download the original MNIST dataset if haven't."""
 80        # just download
 81        MNIST(root=self.root, train=True, download=True)
 82        MNIST(root=self.root, train=False, download=True)
 83
 84        pylogger.debug(
 85            "The original MNIST dataset has been downloaded to %s.", self.root
 86        )
 87
 88    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 89        """Get the training and validation dataset of task `self.task_id`.
 90
 91        **Returns:**
 92        - The train and validation dataset of task `self.task_id`."""
 93        dataset_train_and_val = MNIST(
 94            root=self.root,
 95            train=True,
 96            transform=self.train_and_val_transforms(to_tensor=True),
 97            target_transform=self.target_transforms(),
 98            download=False,
 99        )
100        return random_split(
101            dataset_train_and_val,
102            lengths=[1 - self.validation_percentage, self.validation_percentage],
103        )
104
105    def test_dataset(self) -> Dataset:
106        """Get the test dataset of task `self.task_id`.
107
108        **Returns:**
109        - The test dataset of task `self.task_id`.
110        """
111        return MNIST(
112            root=self.root,
113            train=False,
114            transform=self.test_transforms(to_tensor=True),
115            target_transform=self.target_transforms(),
116            download=False,
117        )

Permuted MNIST dataset.

PermutedMNIST( root: str, num_tasks: int, validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, permutation_mode: str = 'first_channel_only', permutation_seeds: list[int] | None = None)
37    def __init__(
38        self,
39        root: str,
40        num_tasks: int,
41        validation_percentage: float,
42        batch_size: int = 1,
43        num_workers: int = 10,
44        custom_transforms: Callable | transforms.Compose | None = None,
45        custom_target_transforms: Callable | transforms.Compose | None = None,
46        permutation_mode: str = "first_channel_only",
47        permutation_seeds: list[int] | None = None,
48    ) -> None:
49        """Initialise the Permuted MNIST dataset.
50
51        **Args:**
52        - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
53        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
54        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
55        - **batch_size** (`int`): The batch size in train, val, test dataloader.
56        - **num_workers** (`int`): the number of workers for dataloaders.
57        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform.
58        `ToTensor()`, normalise, permute and so on are not included.
59        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
60        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
61            1. 'all': permute all pixels.
62            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
63            3. 'first_channel_only': permute only the first channel.
64        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`.
65        """
66        super().__init__(
67            root=root,
68            num_tasks=num_tasks,
69            validation_percentage=validation_percentage,
70            batch_size=batch_size,
71            num_workers=num_workers,
72            custom_transforms=custom_transforms,
73            custom_target_transforms=custom_target_transforms,
74            permutation_mode=permutation_mode,
75            permutation_seeds=permutation_seeds,
76        )

Initialise the Permuted MNIST dataset.

Args:

  • root (str): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset.
  • validation_percentage (float): the fraction (between 0 and 1) of the training data to randomly split off as validation data; the remaining 1 - validation_percentage is kept for training.
  • batch_size (int): the batch size of the train, validation and test dataloaders.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY the TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalisation, permutation and so on are not included.
  • custom_target_transforms (transform or transforms.Compose or None): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
  • permutation_mode (str): the mode of permutation, should be one of the following:
    1. 'all': permute all pixels.
    2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
    3. 'first_channel_only': permute only the first channel.
  • permutation_seeds (list[int] or None): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as num_tasks. Default is None, which creates a list of seeds from 1 to num_tasks.
num_class: int = 10

The number of classes in MNIST.

img_size: torch.Size = torch.Size([1, 28, 28])

The size of MNIST images.

mean_original: tuple[float] = (0.1307,)

The mean values for normalisation.

std_original: tuple[float] = (0.3081,)

The standard deviation values for normalisation.

def prepare_data(self) -> None:
78    def prepare_data(self) -> None:
79        """Download the original MNIST dataset if haven't."""
80        # just download
81        MNIST(root=self.root, train=True, download=True)
82        MNIST(root=self.root, train=False, download=True)
83
84        pylogger.debug(
85            "The original MNIST dataset has been downloaded to %s.", self.root
86        )

Download the original MNIST dataset if it has not been downloaded yet.

def train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
 88    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 89        """Get the training and validation dataset of task `self.task_id`.
 90
 91        **Returns:**
 92        - The train and validation dataset of task `self.task_id`."""
 93        dataset_train_and_val = MNIST(
 94            root=self.root,
 95            train=True,
 96            transform=self.train_and_val_transforms(to_tensor=True),
 97            target_transform=self.target_transforms(),
 98            download=False,
 99        )
100        return random_split(
101            dataset_train_and_val,
102            lengths=[1 - self.validation_percentage, self.validation_percentage],
103        )

Get the training and validation dataset of task self.task_id.

Returns:

  • The train and validation dataset of task self.task_id.
def test_dataset(self) -> torch.utils.data.dataset.Dataset:
105    def test_dataset(self) -> Dataset:
106        """Get the test dataset of task `self.task_id`.
107
108        **Returns:**
109        - The test dataset of task `self.task_id`.
110        """
111        return MNIST(
112            root=self.root,
113            train=False,
114            transform=self.test_transforms(to_tensor=True),
115            target_transform=self.target_transforms(),
116            download=False,
117        )

Get the test dataset of task self.task_id.

Returns:

  • The test dataset of task self.task_id.