clarena.cl_datasets.split_cifar100

The submodule in cl_datasets for the Split CIFAR100 dataset.

  1"""
  2The submodule in `cl_datasets` for Split CIFAR100 dataset.
  3"""
  4
  5__all__ = ["SplitCIFAR100"]
  6
  7import logging
  8from typing import Callable
  9
 10from torch.utils.data import Dataset, random_split
 11from torchvision.datasets import CIFAR100
 12from torchvision.transforms import transforms
 13
 14from clarena.cl_datasets import CLSplitDataset
 15
 16# always get logger for built-in logging in each module
 17pylogger = logging.getLogger(__name__)
 18
 19
 20class SplitCIFAR100(CLSplitDataset):
 21    """Split CIFAR100 dataset."""
 22
 23    num_classes: int = 100
 24    """The number of classes in CIFAR100 dataset."""
 25
 26    mean_original: tuple[float] = (0.1307,)
 27    """The mean values for normalisation."""
 28
 29    std_original: tuple[float] = (0.3081,)
 30    """The standard deviation values for normalisation."""
 31
 32    def __init__(
 33        self,
 34        root: str,
 35        num_tasks: int,
 36        class_split: list[list[int]],
 37        validation_percentage: float,
 38        batch_size: int = 1,
 39        num_workers: int = 10,
 40        custom_transforms: Callable | transforms.Compose | None = None,
 41        custom_target_transforms: Callable | transforms.Compose | None = None,
 42    ) -> None:
 43        """Initialise the Permuted MNIST dataset.
 44
 45        **Args:**
 46        - **root** (`str`): the root directory where the original MNIST data 'MNIST/raw/train-images-idx3-ubyte' and 'MNIST/raw/t10k-images-idx3-ubyte' live.
 47        - **num_tasks** (`int`): the maximum number of tasks supported by the CL dataset.
 48        - **class_split** (`list[list[int]]`): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
 49        - **validation_percentage** (`float`): the percentage to randomly split some of the training data into validation data.
 50        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 51        - **num_workers** (`int`): the number of workers for dataloaders.
 52        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform.
 53        `ToTensor()`, normalise, permute and so on are not included.
 54        - **custom_target_transforms** (`transform` or `transforms.Compose` or `None`): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
 55        - **permutation_mode** (`str`): the mode of permutation, should be one of the following:
 56            1. 'all': permute all pixels.
 57            2. 'by_channel': permute channel by channel separately. All channels are applied the same permutation order.
 58            3. 'first_channel_only': permute only the first channel.
 59        - **permutation_seeds** (`list[int]` or `None`): the seeds for permutation operations used to construct tasks. Make sure it has the same number of seeds as `num_tasks`. Default is None, which creates a list of seeds from 1 to `num_tasks`.
 60        """
 61        super().__init__(
 62            root=root,
 63            num_tasks=num_tasks,
 64            class_split=class_split,
 65            validation_percentage=validation_percentage,
 66            batch_size=batch_size,
 67            num_workers=num_workers,
 68            custom_transforms=custom_transforms,
 69            custom_target_transforms=custom_target_transforms,
 70        )
 71
 72    def prepare_data(self) -> None:
 73        """Download the original MNIST dataset if haven't."""
 74        # just download
 75        CIFAR100(root=self.root, train=True, download=True)
 76        CIFAR100(root=self.root, train=False, download=True)
 77
 78        pylogger.debug(
 79            "The original CIFAR100 dataset has been downloaded to %s.", self.root
 80        )
 81
 82    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 83        """Get the training and validation dataset of task `self.task_id`.
 84
 85        **Returns:**
 86        - The train and validation dataset of task `self.task_id`."""
 87        dataset_train_and_val = self.get_class_subset(
 88            CIFAR100(
 89                root=self.root,
 90                train=True,
 91                transform=self.train_and_val_transforms(to_tensor=True),
 92                target_transform=self.target_transforms(),
 93                download=False,
 94            ),
 95            classes=self.class_split[self.task_id - 1],
 96        )
 97        return random_split(
 98            dataset_train_and_val,
 99            lengths=[1 - self.validation_percentage, self.validation_percentage],
100        )
101
102    def test_dataset(self) -> Dataset:
103        """Get the test dataset of task `self.task_id`.
104
105        **Returns:**
106        - The test dataset of task `self.task_id`.
107        """
108        return self.get_class_subset(
109            CIFAR100(
110                root=self.root,
111                train=False,
112                transform=self.test_transforms(to_tensor=True),
113                target_transform=self.target_transforms(),
114                download=False,
115            ),
116            classes=self.class_split[self.task_id - 1],
117        )
class SplitCIFAR100(clarena.cl_datasets.base.CLSplitDataset):

Split CIFAR100 dataset.

SplitCIFAR100( root: str, num_tasks: int, class_split: list[list[int]], validation_percentage: float, batch_size: int = 1, num_workers: int = 10, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, custom_target_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None)

Initialise the Split CIFAR100 dataset.

Args:

  • root (str): the root directory where the original CIFAR100 data 'cifar-100-python/' lives.
  • num_tasks (int): the maximum number of tasks supported by the CL dataset.
  • class_split (list[list[int]]): the class split for each task. Each element in the list is a list of class labels (integers starting from 0) to split for a task.
  • validation_percentage (float): the percentage to randomly split some of the training data into validation data.
  • batch_size (int): the batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalise and so on are not included.
  • custom_target_transforms (transform or transforms.Compose or None): the custom target transforms to apply to dataset labels. Can be a single transform, composed transforms or no transform. CL class mapping is not included.
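
A minimal usage sketch based on the signature above (the even 10-way class split, the 'data/' directory and the batch settings are illustrative choices, not defaults):

from clarena.cl_datasets.split_cifar100 import SplitCIFAR100

# Split the 100 classes into 10 tasks of 10 classes each (illustrative split).
class_split = [list(range(10 * t, 10 * (t + 1))) for t in range(10)]

cl_dataset = SplitCIFAR100(
    root="data/",               # where cifar-100-python/ will live
    num_tasks=10,
    class_split=class_split,
    validation_percentage=0.1,  # hold out 10% of training data for validation
    batch_size=32,
    num_workers=4,
)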
num_classes: int = 100

The number of classes in CIFAR100 dataset.

mean_original: tuple[float] = (0.5071, 0.4865, 0.4409)

The mean values of the 3 channels for normalisation.

std_original: tuple[float] = (0.2673, 0.2564, 0.2762)

The standard deviation values of the 3 channels for normalisation.
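
These per-channel statistics parameterise the normalisation step that the base class builds into its transform pipelines. The composed pipeline below is only an illustrative sketch of that step, not the actual CLSplitDataset implementation:

from torchvision import transforms

# Illustrative equivalent of the normalisation these attributes parameterise
# (the real pipeline is assembled inside CLSplitDataset, not here).
normalise = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),
])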

def prepare_data(self) -> None:

Download the original CIFAR100 dataset if it hasn't been downloaded.
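
Since the method only triggers torchvision's download, calling it once before any task setup is enough; for example, reusing the cl_dataset instance sketched above:

cl_dataset.prepare_data()  # fetches cifar-100-python/ into root if absent; a no-op otherwise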

def train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:

Get the training and validation dataset of task self.task_id.

Returns:

  • The train and validation dataset of task self.task_id.
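
The split is delegated to torch.utils.data.random_split with fractional lengths (supported since PyTorch 1.13), so validation_percentage is a proportion of the task's training data rather than a sample count. A self-contained sketch of that behaviour:

import torch
from torch.utils.data import TensorDataset, random_split

toy = TensorDataset(torch.arange(100))
train, val = random_split(toy, lengths=[0.9, 0.1])  # fractions summing to 1
print(len(train), len(val))  # 90 10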
def test_dataset(self) -> torch.utils.data.dataset.Dataset:

Get the test dataset of task self.task_id.

Returns:

  • The test dataset of task self.task_id.
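
get_class_subset is inherited from CLSplitDataset and its source is not shown on this page. The helper below is a hypothetical stand-in that illustrates what such class filtering amounts to on a torchvision CIFAR100 object; the actual helper may differ:

from torch.utils.data import Subset
from torchvision.datasets import CIFAR100

def class_subset_sketch(dataset: CIFAR100, classes: list[int]) -> Subset:
    # Keep only samples whose label is in `classes` (hypothetical stand-in
    # for CLSplitDataset.get_class_subset).
    indices = [i for i, label in enumerate(dataset.targets) if label in classes]
    return Subset(dataset, indices)

test_data = CIFAR100(root="data/", train=False, download=True)
task_test = class_subset_sketch(test_data, classes=[0, 1, 2])  # classes are illustrative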