clarena.stl_datasets.celeba
The submodule in `stl_datasets` for the CelebA dataset.
r"""
The submodule in `stl_datasets` for CelebA dataset.
"""

__all__ = ["CelebA"]

import logging
from typing import Callable

from torch.utils.data import Dataset
from torchvision.datasets import CelebA as CelebARaw
from torchvision.transforms import transforms

from clarena.stl_datasets.base import STLDatasetFromRaw

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CelebA(STLDatasetFromRaw):
    r"""CelebA dataset, using the celebrity identity of each face as the class label.

    The [CelebFaces Attributes Dataset (CelebA)](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) is a large-scale celebrity faces dataset. It consists of 202,599 face images of 10,177 celebrity identities (classes), each a 178x218 color image.

    Note that the original CelebA dataset is an attributes dataset rather than a classification dataset. Only the identity of each face is used here as the class label for classification.
    """

    original_dataset_python_class: type[Dataset] = CelebARaw
    r"""The original dataset class."""

    def __init__(
        self,
        root: str,
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: Callable | transforms.Compose | None = None,
        repeat_channels: int | None = None,
        to_tensor: bool = True,
        resize: tuple[int, int] | None = None,
    ) -> None:
        r"""Forward every argument unchanged to the parent class initializer.

        **Args:**
        - **root** (`str`): the root directory where the original CelebA data 'CelebA/' lives.
        - **batch_size** (`int`): the batch size in the train, val and test dataloaders.
        - **num_workers** (`int`): the number of workers for the dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalization and so on are not included.
        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is `None`, which means no repeat. If not `None`, it should be an integer.
        - **to_tensor** (`bool`): whether to include the `ToTensor()` transform. Default is `True`.
        - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is `None`, which means no resize. If not `None`, it should be a tuple of two integers.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

    def prepare_data(self) -> None:
        r"""Download the original CelebA dataset if it has not been downloaded yet."""
        # each split is instantiated once with `download=True`; the instances are discarded
        for split_name in ("train", "valid", "test"):
            CelebARaw(
                root=self.root,
                split=split_name,
                target_type="identity",
                download=True,
            )

        pylogger.debug(
            "The original CelebA dataset has been downloaded to %s.", self.root
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation dataset.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
        """
        # the "train" split is built first, then "valid"; both use the train-and-val transforms
        train_split, val_split = [
            CelebARaw(
                root=self.root,
                split=split_name,
                target_type="identity",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
                download=False,
            )
            for split_name in ("train", "valid")
        ]
        return train_split, val_split

    def test_dataset(self) -> Dataset:
        r"""Get the test dataset.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset.
        """
        return CelebARaw(
            root=self.root,
            split="test",
            target_type="identity",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )
class
CelebA(clarena.stl_datasets.base.STLDatasetFromRaw):
class CelebA(STLDatasetFromRaw):
    r"""CelebA dataset, using the celebrity identity of each face as the class label.

    The [CelebFaces Attributes Dataset (CelebA)](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) is a large-scale celebrity faces dataset. It consists of 202,599 face images of 10,177 celebrity identities (classes), each a 178x218 color image.

    Note that the original CelebA dataset is an attributes dataset rather than a classification dataset. Only the identity of each face is used here as the class label for classification.
    """

    original_dataset_python_class: type[Dataset] = CelebARaw
    r"""The original dataset class."""

    def __init__(
        self,
        root: str,
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: Callable | transforms.Compose | None = None,
        repeat_channels: int | None = None,
        to_tensor: bool = True,
        resize: tuple[int, int] | None = None,
    ) -> None:
        r"""Forward every argument unchanged to the parent class initializer.

        **Args:**
        - **root** (`str`): the root directory where the original CelebA data 'CelebA/' lives.
        - **batch_size** (`int`): the batch size in the train, val and test dataloaders.
        - **num_workers** (`int`): the number of workers for the dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalization and so on are not included.
        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is `None`, which means no repeat. If not `None`, it should be an integer.
        - **to_tensor** (`bool`): whether to include the `ToTensor()` transform. Default is `True`.
        - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is `None`, which means no resize. If not `None`, it should be a tuple of two integers.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

    def prepare_data(self) -> None:
        r"""Download the original CelebA dataset if it has not been downloaded yet."""
        # each split is instantiated once with `download=True`; the instances are discarded
        for split_name in ("train", "valid", "test"):
            CelebARaw(
                root=self.root,
                split=split_name,
                target_type="identity",
                download=True,
            )

        pylogger.debug(
            "The original CelebA dataset has been downloaded to %s.", self.root
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation dataset.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
        """
        # the "train" split is built first, then "valid"; both use the train-and-val transforms
        train_split, val_split = [
            CelebARaw(
                root=self.root,
                split=split_name,
                target_type="identity",
                transform=self.train_and_val_transforms(),
                target_transform=self.target_transform(),
                download=False,
            )
            for split_name in ("train", "valid")
        ]
        return train_split, val_split

    def test_dataset(self) -> Dataset:
        r"""Get the test dataset.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset.
        """
        return CelebARaw(
            root=self.root,
            split="test",
            target_type="identity",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )
CelebA dataset. The CelebFaces Attributes Dataset (CelebA) is a large-scale celebrity faces dataset. It consists of 202,599 face images of 10,177 celebrity identities (classes), each a 178x218 color image.
Note that the original CelebA dataset is not a classification dataset but an attributes dataset. We only use the identity of each face as the class label for classification.
CelebA( root: str, batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, repeat_channels: int | None = None, to_tensor: bool = True, resize: tuple[int, int] | None = None)
def __init__(
    self,
    root: str,
    batch_size: int = 1,
    num_workers: int = 0,
    custom_transforms: Callable | transforms.Compose | None = None,
    repeat_channels: int | None = None,
    to_tensor: bool = True,
    resize: tuple[int, int] | None = None,
) -> None:
    r"""Forward every argument unchanged to the parent class initializer.

    **Args:**
    - **root** (`str`): the root directory where the original CelebA data 'CelebA/' lives.
    - **batch_size** (`int`): the batch size in the train, val and test dataloaders.
    - **num_workers** (`int`): the number of workers for the dataloaders.
    - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalization and so on are not included.
    - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is `None`, which means no repeat. If not `None`, it should be an integer.
    - **to_tensor** (`bool`): whether to include the `ToTensor()` transform. Default is `True`.
    - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is `None`, which means no resize. If not `None`, it should be a tuple of two integers.
    """
    super().__init__(
        root=root,
        batch_size=batch_size,
        num_workers=num_workers,
        custom_transforms=custom_transforms,
        repeat_channels=repeat_channels,
        to_tensor=to_tensor,
        resize=resize,
    )
Args:
- root (`str`): the root directory where the original CelebA data 'CelebA/' lives.
- batch_size (`int`): the batch size in the train, val and test dataloaders.
- num_workers (`int`): the number of workers for the dataloaders.
- custom_transforms (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY the TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalization and so on are not included.
- repeat_channels (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
- to_tensor (`bool`): whether to include the `ToTensor()` transform. Default is True.
- resize (`tuple[int, int]` | `None`): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset] =
<class 'torchvision.datasets.celeba.CelebA'>
The original dataset class.
def
prepare_data(self) -> None:
def prepare_data(self) -> None:
    r"""Download the original CelebA dataset if it has not been downloaded yet."""
    # each split is instantiated once with `download=True`; the instances are discarded
    for split_name in ("train", "valid", "test"):
        CelebARaw(
            root=self.root,
            split=split_name,
            target_type="identity",
            download=True,
        )

    pylogger.debug(
        "The original CelebA dataset has been downloaded to %s.", self.root
    )
Download the original CelebA dataset if it has not been downloaded yet.
def
train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
    r"""Get the training and validation dataset.

    **Returns:**
    - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
    """
    # the "train" split is built first, then "valid"; both use the train-and-val transforms
    train_split, val_split = [
        CelebARaw(
            root=self.root,
            split=split_name,
            target_type="identity",
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )
        for split_name in ("train", "valid")
    ]
    return train_split, val_split
Get the training and validation dataset.
Returns:
- train_and_val_dataset (`tuple[Dataset, Dataset]`): the train and validation dataset.
def
test_dataset(self) -> torch.utils.data.dataset.Dataset:
def test_dataset(self) -> Dataset:
    r"""Get the test dataset.

    **Returns:**
    - **test_dataset** (`Dataset`): the test dataset.
    """
    return CelebARaw(
        root=self.root,
        split="test",
        target_type="identity",
        transform=self.test_transforms(),
        target_transform=self.target_transform(),
        download=False,
    )
Get the test dataset.
Returns:
- test_dataset (`Dataset`): the test dataset.