clarena.stl_datasets.gtsrb
The submodule in stl_datasets for GTSRB dataset.
r"""
The submodule in `stl_datasets` for the GTSRB dataset.
"""

__all__ = ["GTSRB"]

import logging
from typing import Callable

import torch
from torch.utils.data import Dataset, random_split
from torchvision.datasets import GTSRB as GTSRBRaw
from torchvision.transforms import transforms

from clarena.stl_datasets.base import STLDatasetFromRaw

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class GTSRB(STLDatasetFromRaw):
    r"""GTSRB dataset. The [GTSRB dataset](https://benchmark.ini.rub.de/) (German Traffic Sign Recognition Benchmark) is a collection of traffic sign images. It consists of 51,839 color images of 43 different traffic signs (classes)."""

    original_dataset_python_class: type[Dataset] = GTSRBRaw
    r"""The original dataset class."""

    def __init__(
        self,
        root: str,
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: Callable | transforms.Compose | None = None,
        repeat_channels: int | None = None,
        to_tensor: bool = True,
        resize: tuple[int, int] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original GTSRB data 'GTSRB/' lives.
        - **validation_percentage** (`float`): the fraction (within `[0, 1]`) of the training data to randomly split off as validation data, e.g. `0.1` for 10%.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
        - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.

        **Raises:**
        - **ValueError**: if `validation_percentage` is not within `[0, 1]`.
        """
        # `validation_percentage` is used as a fraction in `random_split()` below;
        # reject out-of-range values (and NaN) here instead of only failing later
        # inside `random_split()`.
        if not 0.0 <= validation_percentage <= 1.0:
            raise ValueError(
                f"validation_percentage must be within [0, 1], got {validation_percentage}."
            )

        super().__init__(
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.validation_percentage: float = validation_percentage
        r"""The fraction of training data to randomly split off as validation data."""

    def prepare_data(self) -> None:
        r"""Download the original GTSRB dataset (train and test splits) if it hasn't been downloaded yet."""

        GTSRBRaw(root=self.root, split="train", download=True)
        GTSRBRaw(root=self.root, split="test", download=True)

        pylogger.debug(
            "The original GTSRB dataset has been downloaded to %s.", self.root
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation datasets.

        The original GTSRB training split is randomly divided into a training part
        (fraction `1 - validation_percentage`) and a validation part (fraction
        `validation_percentage`).

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the training and validation datasets.
        """
        dataset_train_and_val = GTSRBRaw(
            root=self.root,
            split="train",
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )

        return random_split(
            dataset_train_and_val,
            lengths=[1 - self.validation_percentage, self.validation_percentage],
            # The seed must stay fixed so the split is identical across experiments.
            # Don't tie it to the global seed, which may vary across experiments.
            generator=torch.Generator().manual_seed(42),
        )

    def test_dataset(self) -> Dataset:
        r"""Get the test dataset.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset.
        """
        dataset_test = GTSRBRaw(
            root=self.root,
            split="test",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )

        return dataset_test
class
GTSRB(clarena.stl_datasets.base.STLDatasetFromRaw):
class GTSRB(STLDatasetFromRaw):
    r"""GTSRB dataset. The [GTSRB dataset](https://benchmark.ini.rub.de/) (German Traffic Sign Recognition Benchmark) is a collection of traffic sign images. It consists of 51,839 color images of 43 different traffic signs (classes)."""

    original_dataset_python_class: type[Dataset] = GTSRBRaw
    r"""The original dataset class."""

    def __init__(
        self,
        root: str,
        validation_percentage: float,
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: Callable | transforms.Compose | None = None,
        repeat_channels: int | None = None,
        to_tensor: bool = True,
        resize: tuple[int, int] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original GTSRB data 'GTSRB/' lives.
        - **validation_percentage** (`float`): the fraction (within `[0, 1]`) of the training data to randomly split off as validation data, e.g. `0.1` for 10%.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
        - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.

        **Raises:**
        - **ValueError**: if `validation_percentage` is not within `[0, 1]`.
        """
        # `validation_percentage` is used as a fraction in `random_split()` below;
        # reject out-of-range values (and NaN) here instead of only failing later
        # inside `random_split()`.
        if not 0.0 <= validation_percentage <= 1.0:
            raise ValueError(
                f"validation_percentage must be within [0, 1], got {validation_percentage}."
            )

        super().__init__(
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

        self.validation_percentage: float = validation_percentage
        r"""The fraction of training data to randomly split off as validation data."""

    def prepare_data(self) -> None:
        r"""Download the original GTSRB dataset (train and test splits) if it hasn't been downloaded yet."""

        GTSRBRaw(root=self.root, split="train", download=True)
        GTSRBRaw(root=self.root, split="test", download=True)

        pylogger.debug(
            "The original GTSRB dataset has been downloaded to %s.", self.root
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation datasets.

        The original GTSRB training split is randomly divided into a training part
        (fraction `1 - validation_percentage`) and a validation part (fraction
        `validation_percentage`).

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the training and validation datasets.
        """
        dataset_train_and_val = GTSRBRaw(
            root=self.root,
            split="train",
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )

        return random_split(
            dataset_train_and_val,
            lengths=[1 - self.validation_percentage, self.validation_percentage],
            # The seed must stay fixed so the split is identical across experiments.
            # Don't tie it to the global seed, which may vary across experiments.
            generator=torch.Generator().manual_seed(42),
        )

    def test_dataset(self) -> Dataset:
        r"""Get the test dataset.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset.
        """
        dataset_test = GTSRBRaw(
            root=self.root,
            split="test",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )

        return dataset_test
GTSRB dataset. The GTSRB (German Traffic Sign Recognition Benchmark) dataset is a collection of traffic sign images. It consists of 51,839 color images of 43 different traffic signs (classes).
GTSRB( root: str, validation_percentage: float, batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, repeat_channels: int | None = None, to_tensor: bool = True, resize: tuple[int, int] | None = None)
def __init__(
    self,
    root: str,
    validation_percentage: float,
    batch_size: int = 1,
    num_workers: int = 0,
    custom_transforms: Callable | transforms.Compose | None = None,
    repeat_channels: int | None = None,
    to_tensor: bool = True,
    resize: tuple[int, int] | None = None,
) -> None:
    r"""
    **Args:**
    - **root** (`str`): the root directory where the original GTSRB data 'GTSRB/' lives.
    - **validation_percentage** (`float`): the fraction (within `[0, 1]`) of the training data to randomly split off as validation data, e.g. `0.1` for 10%.
    - **batch_size** (`int`): the batch size in train, val, test dataloader.
    - **num_workers** (`int`): the number of workers for dataloaders.
    - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
    - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
    - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
    - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.

    **Raises:**
    - **ValueError**: if `validation_percentage` is not within `[0, 1]`.
    """
    # `validation_percentage` is later used as a fraction in `random_split()`;
    # reject out-of-range values (and NaN) here instead of only failing later
    # inside `random_split()`.
    if not 0.0 <= validation_percentage <= 1.0:
        raise ValueError(
            f"validation_percentage must be within [0, 1], got {validation_percentage}."
        )

    super().__init__(
        root=root,
        batch_size=batch_size,
        num_workers=num_workers,
        custom_transforms=custom_transforms,
        repeat_channels=repeat_channels,
        to_tensor=to_tensor,
        resize=resize,
    )

    self.validation_percentage: float = validation_percentage
    r"""The fraction of training data to randomly split off as validation data."""
Args:
- root (str): the root directory where the original GTSRB data 'GTSRB/' lives.
- validation_percentage (float): the percentage (as a fraction between 0 and 1) of the training data to randomly split off as validation data.
- batch_size (int): the batch size in train, val, test dataloader.
- num_workers (int): the number of workers for dataloaders.
- custom_transforms (transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalize and so on are not included.
- repeat_channels (int | None): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
- to_tensor (bool): whether to include the ToTensor() transform. Default is True.
- resize (tuple[int, int] | None): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset] =
<class 'torchvision.datasets.gtsrb.GTSRB'>
The original dataset class.
validation_percentage: float
The percentage to randomly split some training data into validation data.
def
prepare_data(self) -> None:
def prepare_data(self) -> None:
    r"""Download both splits ("train", then "test") of the original GTSRB dataset into `self.root` if they have not been downloaded yet."""

    # Instantiating torchvision's GTSRB with `download=True` fetches the split;
    # the instances themselves are not kept.
    for split_name in ("train", "test"):
        GTSRBRaw(root=self.root, split=split_name, download=True)

    pylogger.debug(
        "The original GTSRB dataset has been downloaded to %s.", self.root
    )
Download the original GTSRB dataset if it hasn't been downloaded yet.
def
train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
    r"""Get the training and validation datasets.

    The original GTSRB training split is loaded (without downloading) using the
    train/val transforms and the target transform. It is then randomly divided
    into two subsets whose sizes are the fractions `1 - validation_percentage`
    and `validation_percentage` of the whole split.

    **Returns:**
    - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the training and validation subsets.
    """
    full_train_split = GTSRBRaw(
        root=self.root,
        split="train",
        transform=self.train_and_val_transforms(),
        target_transform=self.target_transform(),
        download=False,
    )

    val_fraction = self.validation_percentage
    train_fraction = 1 - val_fraction

    # The seed is deliberately hard-coded (not derived from the global seed,
    # which may differ between experiments) so every experiment gets the same split.
    split_rng = torch.Generator().manual_seed(42)

    return random_split(
        full_train_split,
        lengths=[train_fraction, val_fraction],
        generator=split_rng,
    )
Get the training and validation datasets.
Returns:
- train_and_val_dataset (tuple[Dataset, Dataset]): the training and validation datasets.
def
test_dataset(self) -> torch.utils.data.dataset.Dataset:
def test_dataset(self) -> Dataset:
    r"""Get the test dataset.

    Loads the original GTSRB test split from `self.root` without downloading,
    applying the test transforms and the target transform.

    **Returns:**
    - **test_dataset** (`Dataset`): the test dataset.
    """
    return GTSRBRaw(
        root=self.root,
        split="test",
        transform=self.test_transforms(),
        target_transform=self.target_transform(),
        download=False,
    )
Get the test dataset.
Returns:
- test_dataset (Dataset): the test dataset.