clarena.stl_datasets.ahdd

The submodule in stl_datasets for Arabic Handwritten Digits dataset.

  1r"""
  2The submodule in `stl_datasets` for Arabic Handwritten Digits dataset.
  3"""
  4
  5__all__ = ["ArabicHandwrittenDigits"]
  6
  7import logging
  8from typing import Callable
  9
 10import torch
 11from torch.utils.data import Dataset, random_split
 12from torchvision.transforms import transforms
 13
 14from clarena.stl_datasets.base import STLDatasetFromRaw
 15from clarena.stl_datasets.raw import (
 16    ArabicHandwrittenDigits as ArabicHandwrittenDigitsRaw,
 17)
 18
 19# always get logger for built-in logging in each module
 20pylogger = logging.getLogger(__name__)
 21
 22
 23class ArabicHandwrittenDigits(STLDatasetFromRaw):
 24    r"""Arabic Handwritten Digits dataset. The [Arabic Handwritten Digits dataset](https://www.kaggle.com/datasets/mloey1/ahdd1) is a collection of handwritten Arabic digits (0-9). It consists of 60,000 training and 10,000 test images of handwritten Arabic digits (10 classes), each 28x28 grayscale image (similar to MNIST)."""
 25
 26    original_dataset_python_class: type[Dataset] = ArabicHandwrittenDigitsRaw
 27
 28    def __init__(
 29        self,
 30        root: str,
 31        validation_percentage: float,
 32        batch_size: int = 1,
 33        num_workers: int = 0,
 34        custom_transforms: Callable | transforms.Compose | None = None,
 35        repeat_channels: int | None = None,
 36        to_tensor: bool = True,
 37        resize: tuple[int, int] | None = None,
 38    ) -> None:
 39        r"""
 40        **Args:**
 41        - **root** (`str`): the root directory where the original Arabic Handwritten Digits data 'ArabicHandwrittenDigits/' live.
 42        - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
 43        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 44        - **num_workers** (`int`): the number of workers for dataloaders.
 45        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
 46        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
 47        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
 48        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
 49        """
 50        super().__init__(
 51            root=root,
 52            batch_size=batch_size,
 53            num_workers=num_workers,
 54            custom_transforms=custom_transforms,
 55            repeat_channels=repeat_channels,
 56            to_tensor=to_tensor,
 57            resize=resize,
 58        )
 59
 60        self.validation_percentage: float = validation_percentage
 61        r"""The percentage to randomly split some training data into validation data."""
 62
 63    def prepare_data(self) -> None:
 64        r"""Download the original Arabic Handwritten Digits dataset if haven't. Because the original dataset is published on Kaggle, we need to download it manually. This function will not download the original dataset automatically."""
 65        pass
 66
 67    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 68        """Get the training and validation dataset.
 69
 70        **Returns:**
 71        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
 72        """
 73        dataset_train_and_val = ArabicHandwrittenDigitsRaw(
 74            root=self.root,
 75            train=True,
 76            transform=self.train_and_val_transforms(),
 77            target_transform=self.target_transform(),
 78            download=False,
 79        )
 80
 81        return random_split(
 82            dataset_train_and_val,
 83            lengths=[1 - self.validation_percentage, self.validation_percentage],
 84            generator=torch.Generator().manual_seed(
 85                42
 86            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
 87        )
 88
 89    def test_dataset(self) -> Dataset:
 90        r"""Get the test dataset.
 91
 92        **Returns:**
 93        - **test_dataset** (`Dataset`): the test dataset.
 94        """
 95        dataset_test = ArabicHandwrittenDigitsRaw(
 96            root=self.root,
 97            train=False,
 98            transform=self.test_transforms(),
 99            target_transform=self.target_transform(),
100            download=False,
101        )
102
103        return dataset_test
class ArabicHandwrittenDigits(clarena.stl_datasets.base.STLDatasetFromRaw):
 24class ArabicHandwrittenDigits(STLDatasetFromRaw):
 25    r"""Arabic Handwritten Digits dataset. The [Arabic Handwritten Digits dataset](https://www.kaggle.com/datasets/mloey1/ahdd1) is a collection of handwritten Arabic digits (0-9). It consists of 60,000 training and 10,000 test images of handwritten Arabic digits (10 classes), each 28x28 grayscale image (similar to MNIST)."""
 26
 27    original_dataset_python_class: type[Dataset] = ArabicHandwrittenDigitsRaw
 28
 29    def __init__(
 30        self,
 31        root: str,
 32        validation_percentage: float,
 33        batch_size: int = 1,
 34        num_workers: int = 0,
 35        custom_transforms: Callable | transforms.Compose | None = None,
 36        repeat_channels: int | None = None,
 37        to_tensor: bool = True,
 38        resize: tuple[int, int] | None = None,
 39    ) -> None:
 40        r"""
 41        **Args:**
 42        - **root** (`str`): the root directory where the original Arabic Handwritten Digits data 'ArabicHandwrittenDigits/' live.
 43        - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
 44        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 45        - **num_workers** (`int`): the number of workers for dataloaders.
 46        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
 47        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
 48        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
 49        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
 50        """
 51        super().__init__(
 52            root=root,
 53            batch_size=batch_size,
 54            num_workers=num_workers,
 55            custom_transforms=custom_transforms,
 56            repeat_channels=repeat_channels,
 57            to_tensor=to_tensor,
 58            resize=resize,
 59        )
 60
 61        self.validation_percentage: float = validation_percentage
 62        r"""The percentage to randomly split some training data into validation data."""
 63
 64    def prepare_data(self) -> None:
 65        r"""Download the original Arabic Handwritten Digits dataset if haven't. Because the original dataset is published on Kaggle, we need to download it manually. This function will not download the original dataset automatically."""
 66        pass
 67
 68    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 69        """Get the training and validation dataset.
 70
 71        **Returns:**
 72        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
 73        """
 74        dataset_train_and_val = ArabicHandwrittenDigitsRaw(
 75            root=self.root,
 76            train=True,
 77            transform=self.train_and_val_transforms(),
 78            target_transform=self.target_transform(),
 79            download=False,
 80        )
 81
 82        return random_split(
 83            dataset_train_and_val,
 84            lengths=[1 - self.validation_percentage, self.validation_percentage],
 85            generator=torch.Generator().manual_seed(
 86                42
 87            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
 88        )
 89
 90    def test_dataset(self) -> Dataset:
 91        r"""Get the test dataset.
 92
 93        **Returns:**
 94        - **test_dataset** (`Dataset`): the test dataset.
 95        """
 96        dataset_test = ArabicHandwrittenDigitsRaw(
 97            root=self.root,
 98            train=False,
 99            transform=self.test_transforms(),
100            target_transform=self.target_transform(),
101            download=False,
102        )
103
104        return dataset_test

Arabic Handwritten Digits dataset. The Arabic Handwritten Digits dataset is a collection of handwritten Arabic digits (0-9). It consists of 60,000 training and 10,000 test images of handwritten Arabic digits (10 classes), each 28x28 grayscale image (similar to MNIST).

ArabicHandwrittenDigits( root: str, validation_percentage: float, batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, repeat_channels: int | None = None, to_tensor: bool = True, resize: tuple[int, int] | None = None)
29    def __init__(
30        self,
31        root: str,
32        validation_percentage: float,
33        batch_size: int = 1,
34        num_workers: int = 0,
35        custom_transforms: Callable | transforms.Compose | None = None,
36        repeat_channels: int | None = None,
37        to_tensor: bool = True,
38        resize: tuple[int, int] | None = None,
39    ) -> None:
40        r"""
41        **Args:**
42        - **root** (`str`): the root directory where the original Arabic Handwritten Digits data 'ArabicHandwrittenDigits/' live.
43        - **validation_percentage** (`float`): the percentage to randomly split some training data into validation data.
44        - **batch_size** (`int`): The batch size in train, val, test dataloader.
45        - **num_workers** (`int`): the number of workers for dataloaders.
46        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
47        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
48        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
49        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
50        """
51        super().__init__(
52            root=root,
53            batch_size=batch_size,
54            num_workers=num_workers,
55            custom_transforms=custom_transforms,
56            repeat_channels=repeat_channels,
57            to_tensor=to_tensor,
58            resize=resize,
59        )
60
61        self.validation_percentage: float = validation_percentage
62        r"""The percentage to randomly split some training data into validation data."""

Args:

  • root (str): the root directory where the original Arabic Handwritten Digits data 'ArabicHandwrittenDigits/' lives.
  • validation_percentage (float): the percentage to randomly split some training data into validation data.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalize and so on are not included.
  • repeat_channels (int | None): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
  • to_tensor (bool): whether to include ToTensor() transform. Default is True.
  • resize (tuple[int, int] | None): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset] = <class 'ArabicHandwrittenDigits'>

The original dataset class. It must be provided in subclasses.

validation_percentage: float

The percentage to randomly split some training data into validation data.

def prepare_data(self) -> None:
64    def prepare_data(self) -> None:
65        r"""Download the original Arabic Handwritten Digits dataset if haven't. Because the original dataset is published on Kaggle, we need to download it manually. This function will not download the original dataset automatically."""
66        pass

Download the original Arabic Handwritten Digits dataset if it has not been downloaded yet. Because the original dataset is published on Kaggle, it has to be downloaded manually; this function does not download it automatically.

def train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
68    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
69        """Get the training and validation dataset.
70
71        **Returns:**
72        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
73        """
74        dataset_train_and_val = ArabicHandwrittenDigitsRaw(
75            root=self.root,
76            train=True,
77            transform=self.train_and_val_transforms(),
78            target_transform=self.target_transform(),
79            download=False,
80        )
81
82        return random_split(
83            dataset_train_and_val,
84            lengths=[1 - self.validation_percentage, self.validation_percentage],
85            generator=torch.Generator().manual_seed(
86                42
87            ),  # this must be set fixed to make sure the datasets across experiments are the same. Don't handle it to global seed as it might vary across experiments
88        )

Get the training and validation dataset.

Returns:

  • train_and_val_dataset (tuple[Dataset, Dataset]): the train and validation dataset.
def test_dataset(self) -> torch.utils.data.dataset.Dataset:
 90    def test_dataset(self) -> Dataset:
 91        r"""Get the test dataset.
 92
 93        **Returns:**
 94        - **test_dataset** (`Dataset`): the test dataset.
 95        """
 96        dataset_test = ArabicHandwrittenDigitsRaw(
 97            root=self.root,
 98            train=False,
 99            transform=self.test_transforms(),
100            target_transform=self.target_transform(),
101            download=False,
102        )
103
104        return dataset_test

Get the test dataset.

Returns:

  • test_dataset (Dataset): the test dataset.