clarena.stl_datasets.renderedsst2
The submodule in stl_datasets for Rendered SST2 dataset.
r"""
The submodule in `stl_datasets` for Rendered SST2 dataset.
"""

__all__ = ["RenderedSST2"]

import logging
from typing import Callable

import torch
from torch.utils.data import Dataset, random_split
from torchvision.datasets import RenderedSST2 as RenderedSST2Raw
from torchvision.transforms import transforms

from clarena.stl_datasets.base import STLDatasetFromRaw

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class RenderedSST2(STLDatasetFromRaw):
    r"""Rendered SST2 dataset. The [Rendered SST2 dataset](https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md) is a collection of optical character recognition images. It consists of 9,613 images in 2 classes (positive and negative sentiment), each a 448x448 color image."""

    original_dataset_python_class: type[Dataset] = RenderedSST2Raw
    r"""The original dataset class."""

    def __init__(
        self,
        root: str,
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: Callable | transforms.Compose | None = None,
        repeat_channels: int | None = None,
        to_tensor: bool = True,
        resize: tuple[int, int] | None = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str`): the root directory where the original Rendered SST2 data 'RenderedSST2/' live.
        - **batch_size** (`int`): the batch size in train, val, test dataloader.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included.
        - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer.
        - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True.
        - **resize** (`tuple[int, int]` | `None`): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
        """
        super().__init__(
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

    def prepare_data(self) -> None:
        r"""Download the original Rendered SST2 dataset if it hasn't been downloaded yet."""
        # torchvision skips the download when the files are already present on disk
        RenderedSST2Raw(root=self.root, split="train", download=True)
        RenderedSST2Raw(root=self.root, split="val", download=True)
        RenderedSST2Raw(root=self.root, split="test", download=True)

        pylogger.debug(
            "The original Rendered SST2 dataset has been downloaded to %s.",
            self.root,
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        r"""Get the training and validation dataset.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
        """
        # build each transform pipeline once and share it across both splits,
        # instead of constructing it separately per split
        transform = self.train_and_val_transforms()
        target_transform = self.target_transform()

        dataset_train = RenderedSST2Raw(
            root=self.root,
            split="train",
            transform=transform,
            target_transform=target_transform,
            download=False,
        )

        dataset_val = RenderedSST2Raw(
            root=self.root,
            split="val",
            transform=transform,
            target_transform=target_transform,
            download=False,
        )

        return dataset_train, dataset_val

    def test_dataset(self) -> Dataset:
        r"""Get the test dataset.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset.
        """
        dataset_test = RenderedSST2Raw(
            root=self.root,
            split="test",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )

        return dataset_test
class
RenderedSST2(clarena.stl_datasets.base.STLDatasetFromRaw):
class RenderedSST2(STLDatasetFromRaw):
    r"""Single-task-learning wrapper for the Rendered SST2 dataset.

    The [Rendered SST2 dataset](https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md)
    is a collection of optical character recognition images: 9,613 448x448 color
    images in 2 sentiment classes (positive and negative).
    """

    original_dataset_python_class: type[Dataset] = RenderedSST2Raw
    r"""The original dataset class."""

    def __init__(
        self,
        root: str,
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: Callable | transforms.Compose | None = None,
        repeat_channels: int | None = None,
        to_tensor: bool = True,
        resize: tuple[int, int] | None = None,
    ) -> None:
        r"""Initialise the Rendered SST2 dataset object.

        **Args:**
        - **root** (`str`): the root directory where the original Rendered SST2 data 'RenderedSST2/' live.
        - **batch_size** (`int`): the batch size for the train, val and test dataloaders.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`transform` or `transforms.Compose` or `None`): custom transforms applied to the TRAIN dataset only; `ToTensor()`, normalization and so on are not included here.
        - **repeat_channels** (`int` | `None`): the number of channels to repeat; `None` (default) means no repeat.
        - **to_tensor** (`bool`): whether to include the `ToTensor()` transform. Default is True.
        - **resize** (`tuple[int, int]` | `None`): the target size to resize images to; `None` (default) means no resize.
        """
        # forward every configuration option unchanged to the STL base class
        super().__init__(
            root=root,
            batch_size=batch_size,
            num_workers=num_workers,
            custom_transforms=custom_transforms,
            repeat_channels=repeat_channels,
            to_tensor=to_tensor,
            resize=resize,
        )

    def prepare_data(self) -> None:
        r"""Download the original Rendered SST2 dataset if it is not already on disk."""
        # one constructor call per split triggers torchvision's download logic
        for split in ("train", "val", "test"):
            RenderedSST2Raw(root=self.root, split=split, download=True)

        pylogger.debug(
            "The original Rendered SST2 dataset has been downloaded to %s.",
            self.root,
        )

    def train_and_val_dataset(self) -> tuple[Dataset, Dataset]:
        r"""Build and return the training and validation datasets.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset.
        """
        train_split = RenderedSST2Raw(
            root=self.root,
            split="train",
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )
        val_split = RenderedSST2Raw(
            root=self.root,
            split="val",
            transform=self.train_and_val_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )
        return train_split, val_split

    def test_dataset(self) -> Dataset:
        r"""Build and return the test dataset.

        **Returns:**
        - **test_dataset** (`Dataset`): the test dataset.
        """
        return RenderedSST2Raw(
            root=self.root,
            split="test",
            transform=self.test_transforms(),
            target_transform=self.target_transform(),
            download=False,
        )
Rendered SST2 dataset. The Rendered SST2 dataset is a collection of optical character recognition images. It consists of 9,613 images in 2 classes (positive and negative sentiment), each a 448x448 color image.
RenderedSST2( root: str, batch_size: int = 1, num_workers: int = 0, custom_transforms: Union[Callable, torchvision.transforms.transforms.Compose, NoneType] = None, repeat_channels: int | None = None, to_tensor: bool = True, resize: tuple[int, int] | None = None)
28 def __init__( 29 self, 30 root: str, 31 batch_size: int = 1, 32 num_workers: int = 0, 33 custom_transforms: Callable | transforms.Compose | None = None, 34 repeat_channels: int | None = None, 35 to_tensor: bool = True, 36 resize: tuple[int, int] | None = None, 37 ) -> None: 38 r""" 39 **Args:** 40 - **root** (`str`): the root directory where the original Rendered SST2 data 'RenderedSST2/' live. 41 - **batch_size** (`int`): The batch size in train, val, test dataloader. 42 - **num_workers** (`int`): the number of workers for dataloaders. 43 - **custom_transforms** (`transform` or `transforms.Compose` or `None`): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. `ToTensor()`, normalize and so on are not included. 44 - **repeat_channels** (`int` | `None`): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer. 45 - **to_tensor** (`bool`): whether to include `ToTensor()` transform. Default is True. 46 - **resize** (`tuple[int, int]` | `None` or list of them): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers. 47 """ 48 super().__init__( 49 root=root, 50 batch_size=batch_size, 51 num_workers=num_workers, 52 custom_transforms=custom_transforms, 53 repeat_channels=repeat_channels, 54 to_tensor=to_tensor, 55 resize=resize, 56 )
Args:
- root (
str): the root directory where the original Rendered SST2 data 'RenderedSST2/' live. - batch_size (
int): The batch size in train, val, test dataloader. - num_workers (
int): the number of workers for dataloaders. - custom_transforms (
transform or transforms.Compose or None): the custom transforms to apply to ONLY TRAIN dataset. Can be a single transform, composed transforms or no transform. ToTensor(), normalize and so on are not included. - repeat_channels (
int | None): the number of channels to repeat. Default is None, which means no repeat. If not None, it should be an integer. - to_tensor (
bool): whether to include ToTensor() transform. Default is True. - resize (
tuple[int, int] or None): the size to resize the images to. Default is None, which means no resize. If not None, it should be a tuple of two integers.
original_dataset_python_class: type[torch.utils.data.dataset.Dataset] =
<class 'torchvision.datasets.rendered_sst2.RenderedSST2'>
The original dataset class.
def
prepare_data(self) -> None:
58 def prepare_data(self) -> None: 59 r"""Download the original Rendered SST2 dataset if haven't.""" 60 61 RenderedSST2Raw(root=self.root, split="train", download=True) 62 RenderedSST2Raw(root=self.root, split="val", download=True) 63 RenderedSST2Raw(root=self.root, split="test", download=True) 64 65 pylogger.debug( 66 "The original Rendered SST2 dataset has been downloaded to %s.", 67 self.root, 68 )
Download the original Rendered SST2 dataset if it hasn't been downloaded yet.
def
train_and_val_dataset( self) -> tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]:
70 def train_and_val_dataset(self) -> tuple[Dataset, Dataset]: 71 """Get the training and validation dataset. 72 73 **Returns:** 74 - **train_and_val_dataset** (`tuple[Dataset, Dataset]`): the train and validation dataset. 75 """ 76 dataset_train = RenderedSST2Raw( 77 root=self.root, 78 split="train", 79 transform=self.train_and_val_transforms(), 80 target_transform=self.target_transform(), 81 download=False, 82 ) 83 84 dataset_val = RenderedSST2Raw( 85 root=self.root, 86 split="val", 87 transform=self.train_and_val_transforms(), 88 target_transform=self.target_transform(), 89 download=False, 90 ) 91 92 return dataset_train, dataset_val
Get the training and validation dataset.
Returns:
- train_and_val_dataset (
tuple[Dataset, Dataset]): the train and validation dataset.
def
test_dataset(self) -> torch.utils.data.dataset.Dataset:
94 def test_dataset(self) -> Dataset: 95 r"""Get the test dataset. 96 97 **Returns:** 98 - **test_dataset** (`Dataset`): the test dataset. 99 """ 100 dataset_test = RenderedSST2Raw( 101 root=self.root, 102 split="test", 103 transform=self.test_transforms(), 104 target_transform=self.target_transform(), 105 download=False, 106 ) 107 108 return dataset_test
Get the test dataset.
Returns:
- test_dataset (
Dataset): the test dataset.