clarena.cl_datasets.split_cifar10

The submodule in cl_datasets for the Split CIFAR-10 dataset.

  1r"""
  2The submodule in `cl_datasets` for Split CIFAR-10 dataset.
  3"""
  4
  5__all__ = ["SplitCIFAR10"]
  6
  7import logging
  8from typing import Callable
  9
 10import torch
 11from torch.utils.data import Dataset, random_split
 12from torchvision.datasets import CIFAR10
 13from torchvision.transforms import transforms
 14
 15from clarena.cl_datasets import CLSplitDataset
 16
 17# always get logger for built-in logging in each module
 18pylogger = logging.getLogger(__name__)
 19
 20
 21class SplitCIFAR10(CLSplitDataset):
 22    r"""Split CIFAR-10 dataset. The [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) is a subset of the [80 million tiny images dataset](https://people.csail.mit.edu/torralba/tinyimages/). It consists of 50,000 training and 10,000 test images of 10 classes, each 32x32 color image."""
 23
 24    original_dataset_python_class: type[Dataset] = CIFAR10
 25    r"""The original dataset class."""
 26
 27    def __init__(
 28        self,
 29        root: str,
 30        class_split: dict[int, list[int]],
 31        validation_percentage: float,
 32        batch_size: int | dict[int, int] = 1,
 33        num_workers: int | dict[int, int] = 0,
 34        custom_transforms: (
 35            Callable
 36            | transforms.Compose
 37            | None
 38            | dict[int, Callable | transforms.Compose | None]
 39        ) = None,
 40        repeat_channels: int | None | dict[int, int | None] = None,
 41        to_tensor: bool | dict[int, bool] = True,
 42        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
 43    ) -> None:
 44        r"""
 45        **Args:**
 46        - **root** (`str`): the root directory where the original CIFAR-10 data 'cifar-10-python/' live.
 47        - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs ane the values are lists of class labels (integers starting from 0) to split for each task.
 48        - **validation_percentage** (`float`): The percentage to randomly split some training data into validation data.
 49        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
 50        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
 51        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
 52        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
 53        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
 54        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
 55        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
 56        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
 57        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
 58        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
 59        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
 60        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
 61        """
 62        super().__init__(
 63            root=root,
 64            class_split=class_split,
 65            batch_size=batch_size,
 66            num_workers=num_workers,
 67            custom_transforms=custom_transforms,
 68            repeat_channels=repeat_channels,
 69            to_tensor=to_tensor,
 70            resize=resize,
 71        )
 72
 73        self.validation_percentage: float = validation_percentage
 74        r"""The percentage to randomly split some training data into validation data."""
 75
 76    def prepare_data(self) -> None:
 77        r"""Download the original CIFAR-10 dataset if haven't."""
 78
 79        if self.task_id != 1:
 80            return  # download all original datasets only at the beginning of first task
 81
 82        CIFAR10(root=self.root_t, train=True, download=True)
 83        CIFAR10(root=self.root_t, train=False, download=True)
 84
 85        pylogger.debug(
 86            "The original CIFAR-10 dataset has been downloaded to %s.", self.root
 87        )
 88
 89    def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
 90        r"""Get a subset of classes from the dataset of current classes of `self.task_id`. It is used when constructing the split. It must be implemented by subclasses.
 91
 92        **Args:**
 93        - **dataset** (`Dataset`): the dataset to retrieve subset from.
 94
 95        **Returns:**
 96        - **subset** (`Dataset`): the subset of classes from the dataset.
 97        """
 98        classes = self.class_split[self.task_id]
 99
100        # get the indices of the dataset that belong to the classes
101        idx = [i for i, (_, target) in enumerate(dataset) if target in classes]
102
103        # subset the dataset by the indices, in-place operation
104        dataset.data = dataset.data[idx]  # data is a Numpy ndarray
105        dataset.targets = [dataset.targets[i] for i in idx]  # targets is a list
106
107        dataset.target_transform = self.target_transform()  # cl class mapping should be applied after the split
108
109        return dataset
110
111    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
112        r"""Get the training and validation dataset of task `self.task_id`.
113
114        **Returns:**
115        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.
116        """
117        dataset_train_and_val = self.get_subset_of_classes(
118            CIFAR10(
119                root=self.root_t,
120                train=True,
121                transform=self.train_and_val_transforms(),
122                # cl class mapping should be applied after the split
123                download=False,
124            )
125        )
126
127        return random_split(
128            dataset_train_and_val,
129            lengths=[1 - self.validation_percentage, self.validation_percentage],
130            generator=torch.Generator().manual_seed(
131                42
132            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
133        )
134
135    def test_dataset(self) -> Dataset:
136        r"""Get the test dataset of task `self.task_id`.
137
138        **Returns:**
139        - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`.
140        """
141        dataset_test = self.get_subset_of_classes(
142            CIFAR10(
143                root=self.root_t,
144                train=False,
145                transform=self.test_transforms(),
146                # cl class mapping should be applied after the split
147                download=False,
148            )
149        )
150
151        return dataset_test
class SplitCIFAR10(clarena.cl_datasets.base.CLSplitDataset):
 22class SplitCIFAR10(CLSplitDataset):
 23    r"""Split CIFAR-10 dataset. The [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) is a subset of the [80 million tiny images dataset](https://people.csail.mit.edu/torralba/tinyimages/). It consists of 50,000 training and 10,000 test images of 10 classes, each 32x32 color image."""
 24
 25    original_dataset_python_class: type[Dataset] = CIFAR10
 26    r"""The original dataset class."""
 27
 28    def __init__(
 29        self,
 30        root: str,
 31        class_split: dict[int, list[int]],
 32        validation_percentage: float,
 33        batch_size: int | dict[int, int] = 1,
 34        num_workers: int | dict[int, int] = 0,
 35        custom_transforms: (
 36            Callable
 37            | transforms.Compose
 38            | None
 39            | dict[int, Callable | transforms.Compose | None]
 40        ) = None,
 41        repeat_channels: int | None | dict[int, int | None] = None,
 42        to_tensor: bool | dict[int, bool] = True,
 43        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
 44    ) -> None:
 45        r"""
 46        **Args:**
 47        - **root** (`str`): the root directory where the original CIFAR-10 data 'cifar-10-python/' live.
 48        - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs ane the values are lists of class labels (integers starting from 0) to split for each task.
 49        - **validation_percentage** (`float`): The percentage to randomly split some training data into validation data.
 50        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
 51        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
 52        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
 53        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
 54        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
 55        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
 56        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
 57        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
 58        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
 59        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
 60        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
 61        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
 62        """
 63        super().__init__(
 64            root=root,
 65            class_split=class_split,
 66            batch_size=batch_size,
 67            num_workers=num_workers,
 68            custom_transforms=custom_transforms,
 69            repeat_channels=repeat_channels,
 70            to_tensor=to_tensor,
 71            resize=resize,
 72        )
 73
 74        self.validation_percentage: float = validation_percentage
 75        r"""The percentage to randomly split some training data into validation data."""
 76
 77    def prepare_data(self) -> None:
 78        r"""Download the original CIFAR-10 dataset if haven't."""
 79
 80        if self.task_id != 1:
 81            return  # download all original datasets only at the beginning of first task
 82
 83        CIFAR10(root=self.root_t, train=True, download=True)
 84        CIFAR10(root=self.root_t, train=False, download=True)
 85
 86        pylogger.debug(
 87            "The original CIFAR-10 dataset has been downloaded to %s.", self.root
 88        )
 89
 90    def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
 91        r"""Get a subset of classes from the dataset of current classes of `self.task_id`. It is used when constructing the split. It must be implemented by subclasses.
 92
 93        **Args:**
 94        - **dataset** (`Dataset`): the dataset to retrieve subset from.
 95
 96        **Returns:**
 97        - **subset** (`Dataset`): the subset of classes from the dataset.
 98        """
 99        classes = self.class_split[self.task_id]
100
101        # get the indices of the dataset that belong to the classes
102        idx = [i for i, (_, target) in enumerate(dataset) if target in classes]
103
104        # subset the dataset by the indices, in-place operation
105        dataset.data = dataset.data[idx]  # data is a Numpy ndarray
106        dataset.targets = [dataset.targets[i] for i in idx]  # targets is a list
107
108        dataset.target_transform = self.target_transform()  # cl class mapping should be applied after the split
109
110        return dataset
111
112    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
113        r"""Get the training and validation dataset of task `self.task_id`.
114
115        **Returns:**
116        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.
117        """
118        dataset_train_and_val = self.get_subset_of_classes(
119            CIFAR10(
120                root=self.root_t,
121                train=True,
122                transform=self.train_and_val_transforms(),
123                # cl class mapping should be applied after the split
124                download=False,
125            )
126        )
127
128        return random_split(
129            dataset_train_and_val,
130            lengths=[1 - self.validation_percentage, self.validation_percentage],
131            generator=torch.Generator().manual_seed(
132                42
133            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
134        )
135
136    def test_dataset(self) -> Dataset:
137        r"""Get the test dataset of task `self.task_id`.
138
139        **Returns:**
140        - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`.
141        """
142        dataset_test = self.get_subset_of_classes(
143            CIFAR10(
144                root=self.root_t,
145                train=False,
146                transform=self.test_transforms(),
147                # cl class mapping should be applied after the split
148                download=False,
149            )
150        )
151
152        return dataset_test

Split CIFAR-10 dataset. The CIFAR-10 dataset is a subset of the 80 million tiny images dataset. It consists of 50,000 training and 10,000 test images of 10 classes, each 32x32 color image.

SplitCIFAR10( root: str, class_split: dict[int, list[int]], validation_percentage: float, batch_size: int | dict[int, int] = 1, num_workers: int | dict[int, int] = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType, dict[int, Union[Callable, torchvision.transforms.transforms.Compose, NoneType]]] = None, repeat_channels: int | None | dict[int, int | None] = None, to_tensor: bool | dict[int, bool] = True, resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None)
28    def __init__(
29        self,
30        root: str,
31        class_split: dict[int, list[int]],
32        validation_percentage: float,
33        batch_size: int | dict[int, int] = 1,
34        num_workers: int | dict[int, int] = 0,
35        custom_transforms: (
36            Callable
37            | transforms.Compose
38            | None
39            | dict[int, Callable | transforms.Compose | None]
40        ) = None,
41        repeat_channels: int | None | dict[int, int | None] = None,
42        to_tensor: bool | dict[int, bool] = True,
43        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
44    ) -> None:
45        r"""
46        **Args:**
47        - **root** (`str`): the root directory where the original CIFAR-10 data 'cifar-10-python/' live.
48        - **class_split** (`dict[int, list[int]]`): the dict of classes for each task. The keys are task IDs ane the values are lists of class labels (integers starting from 0) to split for each task.
49        - **validation_percentage** (`float`): The percentage to randomly split some training data into validation data.
50        - **batch_size** (`int` | `dict[int, int]`): the batch size for train, val, and test dataloaders.
51        If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an `int`, it is the same batch size for all tasks.
52        - **num_workers** (`int` | `dict[int, int]`): the number of workers for dataloaders.
53        If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an `int`, it is the same number of workers for all tasks.
54        - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, permute, and so on are not included.
55        If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
56        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
57        If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied.
58        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
59        If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
60        - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize.
61        If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
62        """
63        super().__init__(
64            root=root,
65            class_split=class_split,
66            batch_size=batch_size,
67            num_workers=num_workers,
68            custom_transforms=custom_transforms,
69            repeat_channels=repeat_channels,
70            to_tensor=to_tensor,
71            resize=resize,
72        )
73
74        self.validation_percentage: float = validation_percentage
75        r"""The percentage to randomly split some training data into validation data."""

Args:

  • root (str): the root directory where the original CIFAR-10 data 'cifar-10-python/' lives.
  • class_split (dict[int, list[int]]): the dict of classes for each task. The keys are task IDs and the values are lists of class labels (integers starting from 0) to split for each task.
  • validation_percentage (float): the fraction of the training data (a value between 0 and 1, e.g. 0.1 for 10%) to randomly split off as validation data.
  • batch_size (int | dict[int, int]): the batch size for train, val, and test dataloaders. If it is a dict, the keys are task IDs and the values are the batch sizes for each task. If it is an int, it is the same batch size for all tasks.
  • num_workers (int | dict[int, int]): the number of workers for dataloaders. If it is a dict, the keys are task IDs and the values are the number of workers for each task. If it is an int, it is the same number of workers for all tasks.
  • custom_transforms (transform or transforms.Compose or None or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. ToTensor(), normalization, permute, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is None, no custom transforms are applied.
  • repeat_channels (int | None | dict of them): the number of channels to repeat for each task. Default is None, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an int, it is the same number of channels to repeat for all tasks. If it is None, no repeat is applied.
  • to_tensor (bool | dict[int, bool]): whether to include the ToTensor() transform. Default is True. If it is a dict, the keys are task IDs and the values are whether to include the ToTensor() transform for each task. If it is a single boolean value, it is applied to all tasks.
  • resize (tuple[int, int] | None or dict of them): the size to resize the images to. Default is None, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is None, no resize is applied.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset] = <class 'torchvision.datasets.cifar.CIFAR10'>

The original dataset class.

validation_percentage: float

The fraction of the training data (a value between 0 and 1) to randomly split off as validation data.

def prepare_data(self) -> None:
77    def prepare_data(self) -> None:
78        r"""Download the original CIFAR-10 dataset if haven't."""
79
80        if self.task_id != 1:
81            return  # download all original datasets only at the beginning of first task
82
83        CIFAR10(root=self.root_t, train=True, download=True)
84        CIFAR10(root=self.root_t, train=False, download=True)
85
86        pylogger.debug(
87            "The original CIFAR-10 dataset has been downloaded to %s.", self.root
88        )

Download the original CIFAR-10 dataset if it has not been downloaded yet. The download is only performed at the beginning of the first task.

def get_subset_of_classes( self, dataset: torch.utils.data.dataset.Dataset) -> torch.utils.data.dataset.Dataset:
 90    def get_subset_of_classes(self, dataset: Dataset) -> Dataset:
 91        r"""Get a subset of classes from the dataset of current classes of `self.task_id`. It is used when constructing the split. It must be implemented by subclasses.
 92
 93        **Args:**
 94        - **dataset** (`Dataset`): the dataset to retrieve subset from.
 95
 96        **Returns:**
 97        - **subset** (`Dataset`): the subset of classes from the dataset.
 98        """
 99        classes = self.class_split[self.task_id]
100
101        # get the indices of the dataset that belong to the classes
102        idx = [i for i, (_, target) in enumerate(dataset) if target in classes]
103
104        # subset the dataset by the indices, in-place operation
105        dataset.data = dataset.data[idx]  # data is a Numpy ndarray
106        dataset.targets = [dataset.targets[i] for i in idx]  # targets is a list
107
108        dataset.target_transform = self.target_transform()  # cl class mapping should be applied after the split
109
110        return dataset

Get the subset of the given dataset that contains only the classes of the current task self.task_id. The dataset is filtered in place and the same object is returned. It is used when constructing the split. It must be implemented by subclasses.

Args:

  • dataset (Dataset): the dataset to retrieve subset from.

Returns:

  • subset (Dataset): the subset of classes from the dataset.
def train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
112    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
113        r"""Get the training and validation dataset of task `self.task_id`.
114
115        **Returns:**
116        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset of task `self.task_id`.
117        """
118        dataset_train_and_val = self.get_subset_of_classes(
119            CIFAR10(
120                root=self.root_t,
121                train=True,
122                transform=self.train_and_val_transforms(),
123                # cl class mapping should be applied after the split
124                download=False,
125            )
126        )
127
128        return random_split(
129            dataset_train_and_val,
130            lengths=[1 - self.validation_percentage, self.validation_percentage],
131            generator=torch.Generator().manual_seed(
132                42
133            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
134        )

Get the training and validation dataset of task self.task_id.

Returns:

  • train_and_val_dataset (tuple[Dataset, Dataset]): the train and validation dataset of task self.task_id.
def test_dataset(self) -> torch.utils.data.dataset.Dataset:
136    def test_dataset(self) -> Dataset:
137        r"""Get the test dataset of task `self.task_id`.
138
139        **Returns:**
140        - **test_dataset** (`Dataset`): the test dataset of task `self.task_id`.
141        """
142        dataset_test = self.get_subset_of_classes(
143            CIFAR10(
144                root=self.root_t,
145                train=False,
146                transform=self.test_transforms(),
147                # cl class mapping should be applied after the split
148                download=False,
149            )
150        )
151
152        return dataset_test

Get the test dataset of task self.task_id.

Returns:

  • test_dataset (Dataset): the test dataset of task self.task_id.