clarena.stl_datasets.celeba

The submodule in stl_datasets for the CelebA dataset.

  1r"""
  2The submodule in `stl_datasets` for CelebA dataset.
  3"""
  4
  5__all__ = ["CelebA"]
  6
  7import logging
  8from typing import Callable
  9
 10from torch.utils.data import Dataset
 11from torchvision.datasets import CelebA as CelebARaw
 12from torchvision.transforms import transforms
 13
 14from clarena.stl_datasets.base import STLDatasetFromRaw
 15
 16# always get logger for built-in logging in each module
 17pylogger = logging.getLogger(__name__)
 18
 19
 20class CelebA(STLDatasetFromRaw):
 21    r"""CelebA dataset. The [CelebFaces Attributes Dataset (CelebA)](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) is a large-scale celebrity faces dataset. It consists of 202,599 face images of 10,177 celebrity identities (classes), each 178x218 color image.
 22
 23    Note that the original CelebA dataset is not a classification dataset but an attributes dataset. We only use the identity of each face as the class label for classification.
 24    """
 25
 26    original_dataset_python_class: type[Dataset] = CelebARaw
 27    r"""The original dataset class."""
 28
 29    def __init__(
 30        self,
 31        root: str,
 32        batch_size: int = 1,
 33        num_workers: int = 0,
 34        custom_transforms: Callable | transforms.Compose | None = None,
 35        repeat_channels: int | None = None,
 36        to_tensor: bool = True,
 37        resize: tuple[int, int] | None = None,
 38    ) -> None:
 39        r"""
 40        **Args:**
 41        - **root** (`str`): the root directory where the original CelebA data 'CelebA/' live.
 42        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 43        - **num_workers** (`int`): the number of workers for dataloaders.
 44        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
 45        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
 46        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
 47        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
 48        """
 49        super().__init__(
 50            root=root,
 51            batch_size=batch_size,
 52            num_workers=num_workers,
 53            custom_transforms=custom_transforms,
 54            repeat_channels=repeat_channels,
 55            to_tensor=to_tensor,
 56            resize=resize,
 57        )
 58
 59    def prepare_data(self) -> None:
 60        r"""Download the original CelebA dataset if haven't."""
 61
 62        CelebARaw(root=self.root, split="train", target_type="identity", download=True)
 63        CelebARaw(root=self.root, split="valid", target_type="identity", download=True)
 64        CelebARaw(root=self.root, split="test", target_type="identity", download=True)
 65
 66        pylogger.debug(
 67            "The original CelebA dataset has been downloaded to %s.", self.root
 68        )
 69
 70    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 71        """Get the training and validation dataset.
 72
 73        **Returns:**
 74        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
 75        """
 76        dataset_train = CelebARaw(
 77            root=self.root,
 78            split="train",
 79            target_type="identity",
 80            transform=self.train_and_val_transforms(),
 81            target_transform=self.target_transform(),
 82            download=False,
 83        )
 84
 85        dataset_val = CelebARaw(
 86            root=self.root,
 87            split="valid",
 88            target_type="identity",
 89            transform=self.train_and_val_transforms(),
 90            target_transform=self.target_transform(),
 91            download=False,
 92        )
 93
 94        return dataset_train, dataset_val
 95
 96    def test_dataset(self) -> Dataset:
 97        r"""Get the test dataset.
 98
 99        **Returns:**
100        - **test_dataset** (`Dataset`): the test dataset.
101        """
102        dataset_test = CelebARaw(
103            root=self.root,
104            split="test",
105            target_type="identity",
106            transform=self.test_transforms(),
107            target_transform=self.target_transform(),
108            download=False,
109        )
110
111        return dataset_test
class CelebA(clarena.stl_datasets.base.STLDatasetFromRaw):
 21class CelebA(STLDatasetFromRaw):
 22    r"""CelebA dataset. The [CelebFaces Attributes Dataset (CelebA)](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) is a large-scale celebrity faces dataset. It consists of 202,599 face images of 10,177 celebrity identities (classes), each 178x218 color image.
 23
 24    Note that the original CelebA dataset is not a classification dataset but an attributes dataset. We only use the identity of each face as the class label for classification.
 25    """
 26
 27    original_dataset_python_class: type[Dataset] = CelebARaw
 28    r"""The original dataset class."""
 29
 30    def __init__(
 31        self,
 32        root: str,
 33        batch_size: int = 1,
 34        num_workers: int = 0,
 35        custom_transforms: Callable | transforms.Compose | None = None,
 36        repeat_channels: int | None = None,
 37        to_tensor: bool = True,
 38        resize: tuple[int, int] | None = None,
 39    ) -> None:
 40        r"""
 41        **Args:**
 42        - **root** (`str`): the root directory where the original CelebA data 'CelebA/' live.
 43        - **batch_size** (`int`): The batch size in train, val, test dataloader.
 44        - **num_workers** (`int`): the number of workers for dataloaders.
 45        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
 46        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
 47        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
 48        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
 49        """
 50        super().__init__(
 51            root=root,
 52            batch_size=batch_size,
 53            num_workers=num_workers,
 54            custom_transforms=custom_transforms,
 55            repeat_channels=repeat_channels,
 56            to_tensor=to_tensor,
 57            resize=resize,
 58        )
 59
 60    def prepare_data(self) -> None:
 61        r"""Download the original CelebA dataset if haven't."""
 62
 63        CelebARaw(root=self.root, split="train", target_type="identity", download=True)
 64        CelebARaw(root=self.root, split="valid", target_type="identity", download=True)
 65        CelebARaw(root=self.root, split="test", target_type="identity", download=True)
 66
 67        pylogger.debug(
 68            "The original CelebA dataset has been downloaded to %s.", self.root
 69        )
 70
 71    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
 72        """Get the training and validation dataset.
 73
 74        **Returns:**
 75        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
 76        """
 77        dataset_train = CelebARaw(
 78            root=self.root,
 79            split="train",
 80            target_type="identity",
 81            transform=self.train_and_val_transforms(),
 82            target_transform=self.target_transform(),
 83            download=False,
 84        )
 85
 86        dataset_val = CelebARaw(
 87            root=self.root,
 88            split="valid",
 89            target_type="identity",
 90            transform=self.train_and_val_transforms(),
 91            target_transform=self.target_transform(),
 92            download=False,
 93        )
 94
 95        return dataset_train, dataset_val
 96
 97    def test_dataset(self) -> Dataset:
 98        r"""Get the test dataset.
 99
100        **Returns:**
101        - **test_dataset** (`Dataset`): the test dataset.
102        """
103        dataset_test = CelebARaw(
104            root=self.root,
105            split="test",
106            target_type="identity",
107            transform=self.test_transforms(),
108            target_transform=self.target_transform(),
109            download=False,
110        )
111
112        return dataset_test

CelebA dataset. The CelebFaces Attributes Dataset (CelebA) is a large-scale celebrity faces dataset. It consists of 202,599 face images of 10,177 celebrity identities (classes), each 178x218 color image.

Note that the original CelebA dataset is not a classification dataset but an attributes dataset. We only use the identity of each face as the class label for classification.

CelebA( root: str, batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, repeat_channels: int | None = None, to_tensor: bool = True, resize: tuple[int, int] | None = None)
30    def __init__(
31        self,
32        root: str,
33        batch_size: int = 1,
34        num_workers: int = 0,
35        custom_transforms: Callable | transforms.Compose | None = None,
36        repeat_channels: int | None = None,
37        to_tensor: bool = True,
38        resize: tuple[int, int] | None = None,
39    ) -> None:
40        r"""
41        **Args:**
42        - **root** (`str`): the root directory where the original CelebA data 'CelebA/' live.
43        - **batch_size** (`int`): The batch size in train, val, test dataloader.
44        - **num_workers** (`int`): the number of workers for dataloaders.
45        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
46        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
47        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
48        - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
49        """
50        super().__init__(
51            root=root,
52            batch_size=batch_size,
53            num_workers=num_workers,
54            custom_transforms=custom_transforms,
55            repeat_channels=repeat_channels,
56            to_tensor=to_tensor,
57            resize=resize,
58        )

Args:

  • root (str): the root directory where the original CelebA data 'CelebA/' lives.
  • batch_size (int): The batch size in train, val, test dataloader.
  • num_workers (int): the number of workers for dataloaders.
  • custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalize and so on are not included.
  • repeat_channels (int | None): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
  • to_tensor (bool): whether to include ToTensor() transform. Default is True.
  • resize (tuple[int, int] | None): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset] = <class 'torchvision.datasets.celeba.CelebA'>

The original dataset class.

def prepare_data(self) -> None:
60    def prepare_data(self) -> None:
61        r"""Download the original CelebA dataset if haven't."""
62
63        CelebARaw(root=self.root, split="train", target_type="identity", download=True)
64        CelebARaw(root=self.root, split="valid", target_type="identity", download=True)
65        CelebARaw(root=self.root, split="test", target_type="identity", download=True)
66
67        pylogger.debug(
68            "The original CelebA dataset has been downloaded to %s.", self.root
69        )

Download the original CelebA dataset if it has not been downloaded yet.

def train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
71    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
72        """Get the training and validation dataset.
73
74        **Returns:**
75        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
76        """
77        dataset_train = CelebARaw(
78            root=self.root,
79            split="train",
80            target_type="identity",
81            transform=self.train_and_val_transforms(),
82            target_transform=self.target_transform(),
83            download=False,
84        )
85
86        dataset_val = CelebARaw(
87            root=self.root,
88            split="valid",
89            target_type="identity",
90            transform=self.train_and_val_transforms(),
91            target_transform=self.target_transform(),
92            download=False,
93        )
94
95        return dataset_train, dataset_val

Get the training and validation dataset.

Returns:

  • train_and_val_dataset (tuple[Dataset, Dataset]): the train and validation dataset.
def test_dataset(self) -> torch.utils.data.dataset.Dataset:
 97    def test_dataset(self) -> Dataset:
 98        r"""Get the test dataset.
 99
100        **Returns:**
101        - **test_dataset** (`Dataset`): the test dataset.
102        """
103        dataset_test = CelebARaw(
104            root=self.root,
105            split="test",
106            target_type="identity",
107            transform=self.test_transforms(),
108            target_transform=self.target_transform(),
109            download=False,
110        )
111
112        return dataset_test

Get the test dataset.

Returns:

  • test_dataset (Dataset): the test dataset.