clarena.stl_datasets.emnist
The submodule in stl_datasets for the EMNIST dataset.
r"""
The submodule in `stl_datasets` for EMNIST dataset.
"""

__all__ = ["EMNIST"]

import logging
from typing import Callable

import torch
from torch.utils.data import Dataset, random_split
from torchvision.datasets import EMNIST as EMNISTRaw
from torchvision.transforms import transforms

from clarena.stl_datasets.base import STLDatasetFromRaw
from clarena.stl_datasets.raw import (
    EMNISTBalanced,
    EMNISTByClass,
    EMNISTByMerge,
    EMNISTDigits,
    EMNISTLetters,
)

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class EMNIST(STLDatasetFromRaw):
    r"""EMNIST dataset. The [EMNIST dataset](https://www.nist.gov/itl/products-and-services/emnist-dataset/) is a collection of handwritten letters and digits (A-Z, a-z, 0-9). It consists of 814,255 images in 62 classes; each is a 28x28 grayscale image.

    EMNIST has 6 different splits: `byclass`, `bymerge`, `balanced`, `letters`, `digits` and `mnist`, each containing a different subset of the original collection. This class supports `byclass`, `bymerge`, `balanced`, `letters` and `digits`; `mnist` has no corresponding raw dataset class here and is rejected with a `ValueError`.
    """

    # Maps every accepted `split` value to the raw dataset class used for it.
    # `mnist` is absent because no raw class for it is imported in this module.
    _SPLIT_TO_ORIGINAL_DATASET_CLASS: dict[str, type[Dataset]] = {
        "byclass": EMNISTByClass,
        "bymerge": EMNISTByMerge,
        "balanced": EMNISTBalanced,
        "letters": EMNISTLetters,
        "digits": EMNISTDigits,
    }

    def __init__(
        self,
        root: str,
        split: str,
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: Callable | transforms.Compose | None = None,
        repeat_channels: int | None = None,
        to_tensor: bool = True,
        resize: tuple[int, int] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original EMNIST data folder 'EMNIST/' lives.
        - **split** (`str`): which split of the original EMNIST dataset to use. Supported values are `byclass`, `bymerge`, `balanced`, `letters` and `digits`.
        - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
        - **batch_size** (`int`): The batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalization and so on are not included.
        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
        - **to_tensor** (`bool`): whether to include the `ToTensor()` transform. Default is True.
        - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.

        **Raises:**
        - **ValueError**: if `split` is not one of the supported splits.
        """
        if split not in self._SPLIT_TO_ORIGINAL_DATASET_CLASS:
            # Fail fast instead of leaving `original_dataset_python_class` unset.
            supported = ", ".join(self._SPLIT_TO_ORIGINAL_DATASET_CLASS)
            raise ValueError(
                f"Unsupported EMNIST split {split!r}. Supported splits: {supported}."
            )

        # Set before `super().__init__` to keep the original initialization order.
        self.original_dataset_python_class: type[Dataset] = (
            self._SPLIT_TO_ORIGINAL_DATASET_CLASS[split]
        )
        r"""The original dataset class."""

        super().__init__(
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.split: str = split
        r"""The split of the original EMNIST dataset. It can be `byclass`, `bymerge`, `balanced`, `letters` or `digits`."""

        self.validation_percentage: float = validation_percentage
        r"""The percentage to randomly split some training data into validation data."""

    def prepare_data(self) -> None:
        r"""Download the original EMNIST dataset if it hasn't been downloaded yet."""

        EMNISTRaw(root=self.root, split=self.split, train=True, download=True)
        EMNISTRaw(root=self.root, split=self.split, train=False, download=True)

        pylogger.debug(
            "The original EMNIST dataset has been downloaded to %s.", self.root
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation dataset.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
        """
        dataset_train_and_val = EMNISTRaw(
            root=self.root,
            split=self.split,
            train=True,
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )

        return random_split(
            dataset_train_and_val,
            lengths=[1 - self.validation_percentage, self.validation_percentage],
            generator=torch.Generator().manual_seed(
                42
            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
        )

    def test_dataset(self) -> Dataset:
        r"""Get the test dataset.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset.
        """
        return EMNISTRaw(
            root=self.root,
            split=self.split,
            train=False,
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )
class
EMNIST(clarena.stl_datasets.base.STLDatasetFromRaw):
class EMNIST(STLDatasetFromRaw):
    r"""EMNIST dataset. The [EMNIST dataset](https://www.nist.gov/itl/products-and-services/emnist-dataset/) is a collection of handwritten letters and digits (A-Z, a-z, 0-9). It consists of 814,255 images in 62 classes; each is a 28x28 grayscale image.

    EMNIST has 6 different splits: `byclass`, `bymerge`, `balanced`, `letters`, `digits` and `mnist`, each containing a different subset of the original collection. This class supports `byclass`, `bymerge`, `balanced`, `letters` and `digits`; `mnist` has no corresponding raw dataset class here and is rejected with a `ValueError`.
    """

    # Maps every accepted `split` value to the raw dataset class used for it.
    # `mnist` is absent because no raw class for it is imported in this module.
    _SPLIT_TO_ORIGINAL_DATASET_CLASS: dict[str, type[Dataset]] = {
        "byclass": EMNISTByClass,
        "bymerge": EMNISTByMerge,
        "balanced": EMNISTBalanced,
        "letters": EMNISTLetters,
        "digits": EMNISTDigits,
    }

    def __init__(
        self,
        root: str,
        split: str,
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: Callable | transforms.Compose | None = None,
        repeat_channels: int | None = None,
        to_tensor: bool = True,
        resize: tuple[int, int] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original EMNIST data folder 'EMNIST/' lives.
        - **split** (`str`): which split of the original EMNIST dataset to use. Supported values are `byclass`, `bymerge`, `balanced`, `letters` and `digits`.
        - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
        - **batch_size** (`int`): The batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalization and so on are not included.
        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
        - **to_tensor** (`bool`): whether to include the `ToTensor()` transform. Default is True.
        - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.

        **Raises:**
        - **ValueError**: if `split` is not one of the supported splits.
        """
        if split not in self._SPLIT_TO_ORIGINAL_DATASET_CLASS:
            # Fail fast instead of leaving `original_dataset_python_class` unset.
            supported = ", ".join(self._SPLIT_TO_ORIGINAL_DATASET_CLASS)
            raise ValueError(
                f"Unsupported EMNIST split {split!r}. Supported splits: {supported}."
            )

        # Set before `super().__init__` to keep the original initialization order.
        self.original_dataset_python_class: type[Dataset] = (
            self._SPLIT_TO_ORIGINAL_DATASET_CLASS[split]
        )
        r"""The original dataset class."""

        super().__init__(
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.split: str = split
        r"""The split of the original EMNIST dataset. It can be `byclass`, `bymerge`, `balanced`, `letters` or `digits`."""

        self.validation_percentage: float = validation_percentage
        r"""The percentage to randomly split some training data into validation data."""

    def prepare_data(self) -> None:
        r"""Download the original EMNIST dataset if it hasn't been downloaded yet."""

        EMNISTRaw(root=self.root, split=self.split, train=True, download=True)
        EMNISTRaw(root=self.root, split=self.split, train=False, download=True)

        pylogger.debug(
            "The original EMNIST dataset has been downloaded to %s.", self.root
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation dataset.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
        """
        dataset_train_and_val = EMNISTRaw(
            root=self.root,
            split=self.split,
            train=True,
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )

        return random_split(
            dataset_train_and_val,
            lengths=[1 - self.validation_percentage, self.validation_percentage],
            generator=torch.Generator().manual_seed(
                42
            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
        )

    def test_dataset(self) -> Dataset:
        r"""Get the test dataset.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset.
        """
        return EMNISTRaw(
            root=self.root,
            split=self.split,
            train=False,
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )
EMNIST dataset. The EMNIST dataset is a collection of handwritten letters and digits (A-Z, a-z, 0-9). It consists of 814,255 images in 62 classes; each is a 28x28 grayscale image.
EMNIST has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist, each containing a different subset of the original collection. This class handles byclass, bymerge, balanced, letters and digits; the mnist split is not mapped to a raw dataset class.
EMNIST( root: str, split: str, validation_percentage: float, batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, repeat_channels: int | None = None, to_tensor: bool = True, resize: tuple[int, int] | None = None)
def __init__(
    self,
    root: str,
    split: str,
    validation_percentage: float,
    batch_size: int = 1,
    num_workers: int = 0,
    custom_transforms: Callable | transforms.Compose | None = None,
    repeat_channels: int | None = None,
    to_tensor: bool = True,
    resize: tuple[int, int] | None = None,
) -> None:
    r"""
    **Args:**
    - **root** (`str`): the root directory where the original EMNIST data folder 'EMNIST/' lives.
    - **split** (`str`): which split of the original EMNIST dataset to use. Supported values are `byclass`, `bymerge`, `balanced`, `letters` and `digits`.
    - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
    - **batch_size** (`int`): The batch size in train, val, test dataloader.
    - **num_workers** (`int`): the number of workers for dataloaders.
    - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalization and so on are not included.
    - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
    - **to_tensor** (`bool`): whether to include the `ToTensor()` transform. Default is True.
    - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.

    **Raises:**
    - **ValueError**: if `split` is not one of the supported splits.
    """
    # Maps every accepted `split` value to the raw dataset class used for it.
    # `mnist` is absent because no raw class for it is imported in this module.
    split_to_original_dataset_class: dict[str, type[Dataset]] = {
        "byclass": EMNISTByClass,
        "bymerge": EMNISTByMerge,
        "balanced": EMNISTBalanced,
        "letters": EMNISTLetters,
        "digits": EMNISTDigits,
    }
    if split not in split_to_original_dataset_class:
        # Fail fast instead of leaving `original_dataset_python_class` unset.
        supported = ", ".join(split_to_original_dataset_class)
        raise ValueError(
            f"Unsupported EMNIST split {split!r}. Supported splits: {supported}."
        )

    # Set before `super().__init__` to keep the original initialization order.
    self.original_dataset_python_class: type[Dataset] = (
        split_to_original_dataset_class[split]
    )
    r"""The original dataset class."""

    super().__init__(
        root=root,
        batch_size=batch_size,
        num_workers=num_workers,
        custom_transforms=custom_transforms,
        repeat_channels=repeat_channels,
        to_tensor=to_tensor,
        resize=resize,
    )

    self.split: str = split
    r"""The split of the original EMNIST dataset. It can be `byclass`, `bymerge`, `balanced`, `letters` or `digits`."""

    self.validation_percentage: float = validation_percentage
    r"""The percentage to randomly split some training data into validation data."""
Args:
- root (
str): the root directory where the original EMNIST data folder 'EMNIST/' lives. - split (
str): the original EMNIST dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. This argument specifies which one to use. - validation_percentage (
float): the percentage to randomly split some training data into validation data. - batch_size (
int): The batch size in train, val, test dataloader. - num_workers (
int): the number of workers for dataloaders. - custom_transforms (
transform or transforms.Compose or None): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalization and so on are not included. - repeat_channels (
int | None): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer. - to_tensor (
bool): whether to include the ToTensor() transform. Default is True. - resize (
tuple[int, int] | None): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
split: str
The split of the original EMNIST dataset. It can be byclass, bymerge, balanced, letters, digits or mnist.
validation_percentage: float
The percentage to randomly split some training data into validation data.
def
prepare_data(self) -> None:
def prepare_data(self) -> None:
    r"""Make the original EMNIST data for `self.split` available under `self.root`.

    Instantiates the raw torchvision EMNIST train set and then the test set with
    `download=True`, and logs the target directory afterwards.
    """
    for use_train_portion in (True, False):
        EMNISTRaw(
            root=self.root,
            split=self.split,
            train=use_train_portion,
            download=True,
        )

    pylogger.debug(
        "The original EMNIST dataset has been downloaded to %s.", self.root
    )
Download the original EMNIST dataset if it hasn't been downloaded yet.
def
train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
    r"""Build the training set and carve a validation subset out of it.

    The raw EMNIST training portion is loaded with the train/val transforms and
    then split by `self.validation_percentage` using a fixed-seed generator.

    **Returns:**
    - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
    """
    full_training_portion = EMNISTRaw(
        root=self.root,
        split=self.split,
        train=True,
        transform=self.train_and_val_transforms(),
        target_transform=self.target_transform(),
        download=False,
    )

    val_fraction = self.validation_percentage
    train_fraction = 1 - val_fraction

    # A dedicated generator with a hard-coded seed keeps the split identical across
    # experiments; it is intentionally decoupled from the (per-experiment) global seed.
    fixed_split_rng = torch.Generator().manual_seed(42)

    return random_split(
        full_training_portion,
        lengths=[train_fraction, val_fraction],
        generator=fixed_split_rng,
    )
Get the training and validation dataset.
Returns:
- train_and_val_dataset (
tuple[Dataset, Dataset]): the train and validation dataset.
def
test_dataset(self) -> torch.utils.data.dataset.Dataset:
def test_dataset(self) -> Dataset:
    r"""Load the raw EMNIST test portion for `self.split` with the test transforms.

    **Returns:**
    - **test_dataset** (`Dataset`): the test dataset.
    """
    return EMNISTRaw(
        root=self.root,
        split=self.split,
        train=False,
        download=False,
        transform=self.test_transforms(),
        target_transform=self.target_transform(),
    )
Get the test dataset.
Returns:
- test_dataset (
Dataset): the test dataset.