clarena.stl_datasets.emnist

The submodule in `stl_datasets` for the EMNIST dataset.

  1r"""
  2The submodule in `stl_datasets` for EMNIST dataset.
  3"""
  4
  5__all__ = ["EMNIST"]
  6
  7import logging
  8from typing import Callable
  9
 10import torch
 11from torch.utils.data import Dataset, random_split
 12from torchvision.datasets import EMNIST as EMNISTRaw
 13from torchvision.transforms import transforms
 14
 15from clarena.stl_datasets.base import STLDatasetFromRaw
 16from clarena.stl_datasets.raw import (
 17    EMNISTBalanced,
 18    EMNISTByClass,
 19    EMNISTByMerge,
 20    EMNISTDigits,
 21    EMNISTLetters,
 22)
 23
 24# always get logger for built-in logging in each module
 25pylogger = logging.getLogger(__name__)
 26
 27
 28class EMNIST(STLDatasetFromRaw):
 29    r"""EMNIST dataset. The [EMNIST dataset](https://www.nist.gov/itl/products-and-services/emnist-dataset/) is a collection of handwritten letters and digits (including A-Z, a-z, 0-9). It consists of 814,255 images in 62 classes, each 28x28 grayscale image.
 30
 31    EMNIST has 6 different splits: `byclass`, `bymerge`, `balanced`, `letters`, `digits` and `mnist`, each containing a different subset of the original collection. We support all of them in Permuted EMNIST.
 32    """
 33
 34    def __init__(
 35        self,
 36        root: str,
 37        split: str,
 38        validation_percentage: float,
 39        batch_size: int = 1,
 40        num_workers: int = 0,
 41        custom_transforms: Callable | transforms.Compose | None = None,
 42        repeat_channels: int | None = None,
 43        to_tensor: bool = True,
 44        resize: tuple[int, int] | None = None,
 45    ) -> None:
 46        r"""
 47        **Args:**
 48        - **root** (`str`): the root directory where the original EMNIST data 'EMNIST/' live.
 49        - **split** (`str`): the original EMNIST dataset has 6 different splits: `byclass`, `bymerge`, `balanced`, `letters`, `digits` and `mnist`. This argument specifies which one to use.
 50        - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
 51        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 52        - **num_workers** (`int`): the number of workers for dataloaders.
 53        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
 54        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
 55        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
 56        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
 57        """
 58        if split == "byclass":
 59            self.original_dataset_python_class: type[Dataset] = EMNISTByClass
 60        elif split == "bymerge":
 61            self.original_dataset_python_class: type[Dataset] = EMNISTByMerge
 62        elif split == "balanced":
 63            self.original_dataset_python_class: type[Dataset] = EMNISTBalanced
 64        elif split == "letters":
 65            self.original_dataset_python_class: type[Dataset] = EMNISTLetters
 66        elif split == "digits":
 67            self.original_dataset_python_class: type[Dataset] = EMNISTDigits
 68            r"""The original dataset class."""
 69
 70        super().__init__(
 71            root=root,
 72            batch_size=batch_size,
 73            num_workers=num_workers,
 74            custom_transforms=custom_transforms,
 75            repeat_channels=repeat_channels,
 76            to_tensor=to_tensor,
 77            resize=resize,
 78        )
 79
 80        self.split: str = split
 81        r"""The split of the original EMNIST dataset. It can be `byclass`, `bymerge`, `balanced`, `letters`, `digits` or `mnist`."""
 82
 83        self.validation_percentage: float = validation_percentage
 84        r"""The percentage to randomly split some training data into validation data."""
 85
 86    def prepare_data(self) -> None:
 87        r"""Download the original EMNIST dataset if haven't."""
 88
 89        EMNISTRaw(root=self.root, split=self.split, train=True, download=True)
 90        EMNISTRaw(root=self.root, split=self.split, train=False, download=True)
 91
 92        pylogger.debug(
 93            "The original EMNIST dataset has been downloaded to %s.", self.root
 94        )
 95
 96    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 97        """Get the training and validation dataset.
 98
 99        **Returns:**
100        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
101        """
102        dataset_train_and_val = EMNISTRaw(
103            root=self.root,
104            split=self.split,
105            train=True,
106            transform=self.train_and_val_transforms(),
107            target_transform=self.target_transform(),
108            download=False,
109        )
110
111        return random_split(
112            dataset_train_and_val,
113            lengths=[1 - self.validation_percentage, self.validation_percentage],
114            generator=torch.Generator().manual_seed(
115                42
116            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
117        )
118
119    def test_dataset(self) -> Dataset:
120        r"""Get the test dataset.
121
122        **Returns:**
123        - **test_dataset** (`Dataset`): the test dataset.
124        """
125        dataset_test = EMNISTRaw(
126            root=self.root,
127            split=self.split,
128            train=False,
129            transform=self.test_transforms(),
130            target_transform=self.target_transform(),
131            download=False,
132        )
133
134        return dataset_test
class EMNIST(clarena.stl_datasets.base.STLDatasetFromRaw):
 29class EMNIST(STLDatasetFromRaw):
 30    r"""EMNIST dataset. The [EMNIST dataset](https://www.nist.gov/itl/products-and-services/emnist-dataset/) is a collection of handwritten letters and digits (including A-Z, a-z, 0-9). It consists of 814,255 images in 62 classes, each 28x28 grayscale image.
 31
 32    EMNIST has 6 different splits: `byclass`, `bymerge`, `balanced`, `letters`, `digits` and `mnist`, each containing a different subset of the original collection. We support all of them in Permuted EMNIST.
 33    """
 34
 35    def __init__(
 36        self,
 37        root: str,
 38        split: str,
 39        validation_percentage: float,
 40        batch_size: int = 1,
 41        num_workers: int = 0,
 42        custom_transforms: Callable | transforms.Compose | None = None,
 43        repeat_channels: int | None = None,
 44        to_tensor: bool = True,
 45        resize: tuple[int, int] | None = None,
 46    ) -> None:
 47        r"""
 48        **Args:**
 49        - **root** (`str`): the root directory where the original EMNIST data 'EMNIST/' live.
 50        - **split** (`str`): the original EMNIST dataset has 6 different splits: `byclass`, `bymerge`, `balanced`, `letters`, `digits` and `mnist`. This argument specifies which one to use.
 51        - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
 52        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 53        - **num_workers** (`int`): the number of workers for dataloaders.
 54        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
 55        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
 56        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
 57        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
 58        """
 59        if split == "byclass":
 60            self.original_dataset_python_class: type[Dataset] = EMNISTByClass
 61        elif split == "bymerge":
 62            self.original_dataset_python_class: type[Dataset] = EMNISTByMerge
 63        elif split == "balanced":
 64            self.original_dataset_python_class: type[Dataset] = EMNISTBalanced
 65        elif split == "letters":
 66            self.original_dataset_python_class: type[Dataset] = EMNISTLetters
 67        elif split == "digits":
 68            self.original_dataset_python_class: type[Dataset] = EMNISTDigits
 69            r"""The original dataset class."""
 70
 71        super().__init__(
 72            root=root,
 73            batch_size=batch_size,
 74            num_workers=num_workers,
 75            custom_transforms=custom_transforms,
 76            repeat_channels=repeat_channels,
 77            to_tensor=to_tensor,
 78            resize=resize,
 79        )
 80
 81        self.split: str = split
 82        r"""The split of the original EMNIST dataset. It can be `byclass`, `bymerge`, `balanced`, `letters`, `digits` or `mnist`."""
 83
 84        self.validation_percentage: float = validation_percentage
 85        r"""The percentage to randomly split some training data into validation data."""
 86
 87    def prepare_data(self) -> None:
 88        r"""Download the original EMNIST dataset if haven't."""
 89
 90        EMNISTRaw(root=self.root, split=self.split, train=True, download=True)
 91        EMNISTRaw(root=self.root, split=self.split, train=False, download=True)
 92
 93        pylogger.debug(
 94            "The original EMNIST dataset has been downloaded to %s.", self.root
 95        )
 96
 97    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 98        """Get the training and validation dataset.
 99
100        **Returns:**
101        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
102        """
103        dataset_train_and_val = EMNISTRaw(
104            root=self.root,
105            split=self.split,
106            train=True,
107            transform=self.train_and_val_transforms(),
108            target_transform=self.target_transform(),
109            download=False,
110        )
111
112        return random_split(
113            dataset_train_and_val,
114            lengths=[1 - self.validation_percentage, self.validation_percentage],
115            generator=torch.Generator().manual_seed(
116                42
117            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
118        )
119
120    def test_dataset(self) -> Dataset:
121        r"""Get the test dataset.
122
123        **Returns:**
124        - **test_dataset** (`Dataset`): the test dataset.
125        """
126        dataset_test = EMNISTRaw(
127            root=self.root,
128            split=self.split,
129            train=False,
130            transform=self.test_transforms(),
131            target_transform=self.target_transform(),
132            download=False,
133        )
134
135        return dataset_test

EMNIST dataset. The EMNIST dataset is a collection of handwritten letters and digits (including A-Z, a-z, 0-9). It consists of 814,255 images in 62 classes, each 28x28 grayscale image.

EMNIST has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist, each containing a different subset of the original collection. We support all of these splits.

EMNIST( root: str, split: str, validation_percentage: float, batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, repeat_channels: int | None = None, to_tensor: bool = True, resize: tuple[int, int] | None = None)
35    def __init__(
36        self,
37        root: str,
38        split: str,
39        validation_percentage: float,
40        batch_size: int = 1,
41        num_workers: int = 0,
42        custom_transforms: Callable | transforms.Compose | None = None,
43        repeat_channels: int | None = None,
44        to_tensor: bool = True,
45        resize: tuple[int, int] | None = None,
46    ) -> None:
47        r"""
48        **Args:**
49        - **root** (`str`): the root directory where the original EMNIST data 'EMNIST/' live.
50        - **split** (`str`): the original EMNIST dataset has 6 different splits: `byclass`, `bymerge`, `balanced`, `letters`, `digits` and `mnist`. This argument specifies which one to use.
51        - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
52        - **batch_size** (`int`): The batch size in train, val, test dataloader.
53        - **num_workers** (`int`): the number of workers for dataloaders.
54        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
55        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
56        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
57        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
58        """
59        if split == "byclass":
60            self.original_dataset_python_class: type[Dataset] = EMNISTByClass
61        elif split == "bymerge":
62            self.original_dataset_python_class: type[Dataset] = EMNISTByMerge
63        elif split == "balanced":
64            self.original_dataset_python_class: type[Dataset] = EMNISTBalanced
65        elif split == "letters":
66            self.original_dataset_python_class: type[Dataset] = EMNISTLetters
67        elif split == "digits":
68            self.original_dataset_python_class: type[Dataset] = EMNISTDigits
69            r"""The original dataset class."""
70
71        super().__init__(
72            root=root,
73            batch_size=batch_size,
74            num_workers=num_workers,
75            custom_transforms=custom_transforms,
76            repeat_channels=repeat_channels,
77            to_tensor=to_tensor,
78            resize=resize,
79        )
80
81        self.split: str = split
82        r"""The split of the original EMNIST dataset. It can be `byclass`, `bymerge`, `balanced`, `letters`, `digits` or `mnist`."""
83
84        self.validation_percentage: float = validation_percentage
85        r"""The percentage to randomly split some training data into validation data."""

Args:

  • root (str): the root directory where the original EMNIST data directory 'EMNIST/' lives.
  • split (str): the original EMNIST dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. This argument specifies which one to use.
  • validation_percentage (float): the percentage to randomly split some training data into validation data.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalize and so on are not included.
  • repeat_channels (int | None): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
  • to_tensor (bool): whether to include ToTensor() transform. Default is True.
  • resize (tuple[int, int] | None): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
split: str

The split of the original EMNIST dataset. It can be byclass, bymerge, balanced, letters, digits or mnist.

validation_percentage: float

The percentage to randomly split some training data into validation data.

def prepare_data(self) -> None:
87    def prepare_data(self) -> None:
88        r"""Download the original EMNIST dataset if haven't."""
89
90        EMNISTRaw(root=self.root, split=self.split, train=True, download=True)
91        EMNISTRaw(root=self.root, split=self.split, train=False, download=True)
92
93        pylogger.debug(
94            "The original EMNIST dataset has been downloaded to %s.", self.root
95        )

Download the original EMNIST dataset if it hasn't been downloaded yet.

def train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
 97    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 98        """Get the training and validation dataset.
 99
100        **Returns:**
101        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
102        """
103        dataset_train_and_val = EMNISTRaw(
104            root=self.root,
105            split=self.split,
106            train=True,
107            transform=self.train_and_val_transforms(),
108            target_transform=self.target_transform(),
109            download=False,
110        )
111
112        return random_split(
113            dataset_train_and_val,
114            lengths=[1 - self.validation_percentage, self.validation_percentage],
115            generator=torch.Generator().manual_seed(
116                42
117            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
118        )

Get the training and validation dataset.

Returns:

  • train_and_val_dataset (tuple[Dataset, Dataset]): the train and validation dataset.
def test_dataset(self) -> torch.utils.data.dataset.Dataset:
120    def test_dataset(self) -> Dataset:
121        r"""Get the test dataset.
122
123        **Returns:**
124        - **test_dataset** (`Dataset`): the test dataset.
125        """
126        dataset_test = EMNISTRaw(
127            root=self.root,
128            split=self.split,
129            train=False,
130            transform=self.test_transforms(),
131            target_transform=self.target_transform(),
132            download=False,
133        )
134
135        return dataset_test

Get the test dataset.

Returns:

  • test_dataset (Dataset): the test dataset.