clarena.mtl_datasets
Multi-Task Learning Datasets
This submodule provides the multi-task learning datasets that can be used in CLArena.
Here are the base classes for multi-task learning datasets, which inherit from Lightning `LightningDataModule`:
- `MTLDataset`: The base class for all multi-task learning datasets.
    - `MTLCombinedDataset`: The base class for combined multi-task learning datasets. A child class of `MTLDataset`.
    - `MTLDatasetFromCL`: The base class for constructing multi-task learning datasets from continual learning datasets. A child class of `MTLDataset`.
Please note that this is an API documentation page. Please refer to the main documentation pages for more information about how to configure and implement MTL datasets:

- [**Configure MTL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/mtl-dataset)
- [**Implement Custom MTL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/mtl_dataset)
```python
r"""

# Multi-Task Learning Datasets

This submodule provides the **multi-task learning datasets** that can be used in CLArena.

Here are the base classes for multi-task learning datasets, which inherit from Lightning `LightningDataModule`:

- `MTLDataset`: The base class for all multi-task learning datasets.
    - `MTLCombinedDataset`: The base class for combined multi-task learning datasets. A child class of `MTLDataset`.
    - `MTLDatasetFromCL`: The base class for constructing multi-task learning datasets from continual learning datasets. A child class of `MTLDataset`.

Please note that this is an API documentation page. Please refer to the main documentation pages for more information about how to configure and implement MTL datasets:

- [**Configure MTL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/mtl-dataset)
- [**Implement Custom MTL Dataset**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/mtl_dataset)

"""

from .base import MTLCombinedDataset, MTLDataset, MTLDatasetFromCL
from .combined import Combined

__all__ = ["MTLDataset", "MTLCombinedDataset", "MTLDatasetFromCL", "Combined"]
```
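Most constructor arguments of the base classes accept either a single value (applied to every task) or a per-task dict keyed by task ID. A minimal standalone sketch of that broadcasting convention (`broadcast_per_task` is a hypothetical helper written for illustration, not the actual CLArena code, which inlines this logic per attribute):

```python
from typing import Any


def broadcast_per_task(value: Any, num_tasks: int) -> dict[int, Any]:
    """Broadcast a single value to a {task_id: value} dict; task IDs start at 1.

    A dict is returned as-is and is expected to carry exactly the keys
    1..num_tasks, which is what `MTLDataset.sanity_check` enforces.
    """
    if isinstance(value, dict):
        return value
    return {t: value for t in range(1, num_tasks + 1)}


# A scalar applies to every task; a dict configures each task individually.
print(broadcast_per_task(True, 3))            # {1: True, 2: True, 3: True}
print(broadcast_per_task({1: 32, 2: 64}, 2))  # {1: 32, 2: 64}
```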
```python
# Imports needed by this excerpt. `pylogger` (the module logger) and
# `ClassMapping` (the target-transform helper) are defined elsewhere in CLArena.
from abc import abstractmethod
from typing import Any, Callable

from lightning import LightningDataModule
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import transforms


class MTLDataset(LightningDataModule):
    r"""The base class of multi-task learning datasets."""

    def __init__(
        self,
        root: str | dict[int, str],
        num_tasks: int,
        sampling_strategy: str = "mixed",
        batch_size: int = 1,
        num_workers: int = 0,
        custom_transforms: (
            Callable
            | transforms.Compose
            | None
            | dict[int, Callable | transforms.Compose | None]
        ) = None,
        repeat_channels: int | None | dict[int, int | None] = None,
        to_tensor: bool | dict[int, bool] = True,
        resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None,
    ) -> None:
        r"""
        **Args:**
        - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If it is a dict, the keys are task IDs and the values are the root directory for each task.
        - **num_tasks** (`int`): the maximum number of tasks supported by the MTL dataset.
        - **sampling_strategy** (`str`): the sampling strategy that constructs training batches from each task's dataset; one of:
            - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
        - **batch_size** (`int`): the batch size of the train, validation and test dataloaders.
        - **num_workers** (`int`): the number of workers for dataloaders.
        - **custom_transforms** (`Callable` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization and so on are not included.
            If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
        - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat.
            If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, the same number of channels is repeated for all tasks. If it is `None`, no repeat is applied.
        - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`.
            If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
        - **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize.
            If it is a dict, the keys are task IDs and the values are the sizes to resize to for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
        """
        super().__init__()

        self.root: dict[int, str] = (
            OmegaConf.to_container(root)
            if isinstance(root, DictConfig)
            else root
            if isinstance(root, dict)
            else {t: root for t in range(1, num_tasks + 1)}
        )
        r"""The dict of root directories of the original data files for each task."""
        self.num_tasks: int = num_tasks
        r"""The maximum number of tasks supported by the dataset."""
        self.sampling_strategy: str = sampling_strategy
        r"""The sampling strategy for constructing training batches from each task's dataset."""
        self.batch_size: int = batch_size
        r"""The batch size for dataloaders."""
        self.num_workers: int = num_workers
        r"""The number of workers for dataloaders."""

        self.custom_transforms: dict[int, Callable | transforms.Compose | None] = (
            OmegaConf.to_container(custom_transforms)
            if isinstance(custom_transforms, DictConfig)
            else custom_transforms
            if isinstance(custom_transforms, dict)
            else {t: custom_transforms for t in range(1, num_tasks + 1)}
        )
        r"""The dict of custom transforms for each task."""
        self.repeat_channels: dict[int, int | None] = (
            OmegaConf.to_container(repeat_channels)
            if isinstance(repeat_channels, DictConfig)
            else repeat_channels
            if isinstance(repeat_channels, dict)
            else {t: repeat_channels for t in range(1, num_tasks + 1)}
        )
        r"""The dict of number of channels to repeat for each task."""
        self.to_tensor: dict[int, bool] = (
            OmegaConf.to_container(to_tensor)
            if isinstance(to_tensor, DictConfig)
            else to_tensor
            if isinstance(to_tensor, dict)
            else {t: to_tensor for t in range(1, num_tasks + 1)}
        )
        r"""The dict of to_tensor flags for each task."""
        self.resize: dict[int, tuple[int, int] | None] = (
            {t: tuple(rs) if rs else None for t, rs in resize.items()}
            if isinstance(resize, (dict, DictConfig))
            else {
                t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1)
            }
        )
        r"""The dict of sizes to resize to for each task."""

        # dataset containers
        self.dataset_train: dict[int, Any] = {}
        r"""The dictionary storing the training dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

        Note that they must be task-labelled, i.e. the elements in the dataset objects must be tuples of (input, target, task_id). Use the `TaskLabelledDataset` wrapper if necessary."""
        self.dataset_val: dict[int, Any] = {}
        r"""The dictionary storing the validation dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

        Note that they must be task-labelled, i.e. the elements in the dataset objects must be tuples of (input, target, task_id). Use the `TaskLabelledDataset` wrapper if necessary."""
        self.dataset_test: dict[int, Any] = {}
        r"""The dictionary storing the test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.

        Note that they must be task-labelled, i.e. the elements in the dataset objects must be tuples of (input, target, task_id). Use the `TaskLabelledDataset` wrapper if necessary."""

        self.mean: dict[int, float] = {}
        r"""The dict of mean values for normalization for each task. Used when constructing the transforms."""
        self.std: dict[int, float] = {}
        r"""The dict of standard deviation values for normalization for each task. Used when constructing the transforms."""

        # task ID controls
        self.train_tasks: list[int]
        r"""The list of task IDs to be trained. Each integer is a task ID."""
        self.eval_tasks: list[int]
        r"""The list of task IDs to be evaluated. Each integer is a task ID."""

        # call the base-class check explicitly so subclass overrides don't skip it
        MTLDataset.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Check that every per-task dict covers exactly the task IDs 1 to `num_tasks`."""
        for attr in [
            "root",
            "custom_transforms",
            "repeat_channels",
            "to_tensor",
            "resize",
        ]:
            value = getattr(self, attr)
            expected_keys = set(range(1, self.num_tasks + 1))
            if set(value.keys()) != expected_keys:
                raise ValueError(
                    f"{attr} dict keys must be consecutive integers from 1 to num_tasks."
                )

    @abstractmethod
    def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
        r"""Get the mapping of classes of task `task_id` to fit multi-task learning. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to query the class map for.

        **Returns:**
        - **class_map** (`dict[str | int, int]`): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be contiguous integers starting from 0.
        """

    @abstractmethod
    def prepare_data(self) -> None:
        r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDataModule`."""

    def setup(self, stage: str) -> None:
        r"""Set up the dataset for different stages.

        **Args:**
        - **stage** (`str`): the stage of the experiment; one of:
            - 'fit': training and validation datasets should be assigned to `self.dataset_train` and `self.dataset_val`.
            - 'test': test dataset should be assigned to `self.dataset_test`.
        """
        if stage == "fit":
            # train and validation datasets must be constructed together because
            # a validation sanity check is conducted before training
            pylogger.debug("Constructing train and validation datasets ...")

            for task_id in self.train_tasks:
                self.dataset_train[task_id], self.dataset_val[task_id] = (
                    self.train_and_val_dataset(task_id)
                )

                pylogger.info(
                    "Train and validation datasets for task %d are ready.", task_id
                )
                pylogger.info(
                    "Train dataset for task %d size: %d",
                    task_id,
                    len(self.dataset_train[task_id]),
                )
                pylogger.info(
                    "Validation dataset for task %d size: %d",
                    task_id,
                    len(self.dataset_val[task_id]),
                )

        elif stage == "test":
            pylogger.debug("Constructing test dataset ...")

            for task_id in self.eval_tasks:
                self.dataset_test[task_id] = self.test_dataset(task_id)

                pylogger.info("Test dataset for task %d is ready.", task_id)
                pylogger.info(
                    "Test dataset for task %d size: %d",
                    task_id,
                    len(self.dataset_test[task_id]),
                )

    def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
        r"""Set up tasks for the multi-task learning experiment.

        **Args:**
        - **train_tasks** (`list[int]`): the list of task IDs to be trained. Used when constructing the train/val dataloaders.
        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. Used when constructing the test dataloader.
        """
        self.train_tasks = train_tasks
        self.eval_tasks = eval_tasks

    def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
        r"""Set up evaluation tasks for the multi-task learning evaluation.

        **Args:**
        - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.
        """
        self.eval_tasks = eval_tasks

    def train_and_val_transforms(self, task_id: int) -> transforms.Compose:
        r"""Transforms for the train and validation datasets of task `task_id`, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.

        **Args:**
        - **task_id** (`int`): the task ID of the training and validation dataset to get the transforms for.

        **Returns:**
        - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels[task_id])
            if self.repeat_channels[task_id] is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor[task_id] else None
        resize_transform = (
            transforms.Resize(self.resize[task_id])
            if self.resize[task_id] is not None
            else None
        )
        normalization_transform = transforms.Normalize(
            self.mean[task_id], self.std[task_id]
        )

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        self.custom_transforms[task_id],
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters

    def test_transforms(self, task_id: int) -> transforms.Compose:
        r"""Transforms for the test dataset of task `task_id`. Only basic transforms like normalization and `ToTensor()` are included. It can be used in subclasses when constructing the dataset.

        **Args:**
        - **task_id** (`int`): the task ID of the test dataset to get the transforms for.

        **Returns:**
        - **test_transforms** (`transforms.Compose`): the composed test transforms.
        """
        repeat_channels_transform = (
            transforms.Grayscale(num_output_channels=self.repeat_channels[task_id])
            if self.repeat_channels[task_id] is not None
            else None
        )
        to_tensor_transform = transforms.ToTensor() if self.to_tensor[task_id] else None
        resize_transform = (
            transforms.Resize(self.resize[task_id])
            if self.resize[task_id] is not None
            else None
        )
        normalization_transform = transforms.Normalize(
            self.mean[task_id], self.std[task_id]
        )

        return transforms.Compose(
            list(
                filter(
                    None,
                    [
                        repeat_channels_transform,
                        to_tensor_transform,
                        resize_transform,
                        normalization_transform,
                    ],
                )
            )
        )  # the order of transforms matters. No custom transforms for test

    def target_transform(self, task_id: int) -> Callable:
        r"""Target transform for task `task_id`, which maps the original class labels to the integer class labels for multi-task learning. It can be used in subclasses when constructing the dataset.

        **Args:**
        - **task_id** (`int`): the task ID of the dataset to get the target transform for.

        **Returns:**
        - **target_transform** (`Callable`): the target transform function.
        """
        class_map = self.get_mtl_class_map(task_id)

        target_transform = ClassMapping(class_map=class_map)
        return target_transform

    @abstractmethod
    def train_and_val_dataset(self, task_id: int) -> tuple[Any, Any]:
        r"""Get the training and validation dataset of task `task_id`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to get the training and validation dataset for.

        **Returns:**
        - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation dataset of task `task_id`.
        """

    @abstractmethod
    def test_dataset(self, task_id: int) -> Any:
        r"""Get the test dataset of task `task_id`. It must be implemented by subclasses.

        **Args:**
        - **task_id** (`int`): the task ID to get the test dataset for.

        **Returns:**
        - **test_dataset** (`Any`): the test dataset of task `task_id`.
        """

    def train_dataloader(self) -> DataLoader:
        r"""DataLoader generator for the train stage. It is automatically called before training.

        **Returns:**
        - **train_dataloader** (`DataLoader`): the train DataLoader.
        """
        pylogger.debug(
            "Constructing train dataloader ... sampling strategy: %s",
            self.sampling_strategy,
        )

        if self.sampling_strategy == "mixed":
            # mixed sampling strategy, which samples from all tasks' datasets
            concatenated_dataset = ConcatDataset(
                [self.dataset_train[task_id] for task_id in self.train_tasks]
            )

            return DataLoader(
                dataset=concatenated_dataset,
                batch_size=self.batch_size,
                shuffle=True,  # shuffle train batches to prevent overfitting
                num_workers=self.num_workers,
                drop_last=True,  # to avoid batchnorm error (when the last batch has size 1)
            )

        raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")

    def val_dataloader(self) -> dict[int, DataLoader]:
        r"""DataLoader generator for the validation stage. It is automatically called before validation.

        **Returns:**
        - **val_dataloader** (`dict[int, DataLoader]`): the validation DataLoaders, keyed by task ID.
        """
        pylogger.debug("Constructing validation dataloader ...")

        return {
            task_id: DataLoader(
                dataset=dataset_val_t,
                batch_size=self.batch_size,
                shuffle=False,  # no need to shuffle val or test batches
                num_workers=self.num_workers,
            )
            for task_id, dataset_val_t in self.dataset_val.items()
        }

    def test_dataloader(self) -> dict[int, DataLoader]:
        r"""DataLoader generator for the test stage. It is automatically called before testing.

        **Returns:**
        - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoaders, keyed by task ID.
        """
        pylogger.debug("Constructing test dataloader ...")

        return {
            task_id: DataLoader(
                dataset=dataset_test_t,
                batch_size=self.batch_size,
                shuffle=False,  # no need to shuffle val or test batches
                num_workers=self.num_workers,
            )
            for task_id, dataset_test_t in self.dataset_test.items()
        }
```
The base class of multi-task learning datasets.
35 def __init__( 36 self, 37 root: str | dict[int, str], 38 num_tasks: int, 39 sampling_strategy: str = "mixed", 40 batch_size: int = 1, 41 num_workers: int = 0, 42 custom_transforms: ( 43 Callable 44 | transforms.Compose 45 | None 46 | dict[int, Callable | transforms.Compose | None] 47 ) = None, 48 repeat_channels: int | None | dict[int, int | None] = None, 49 to_tensor: bool | dict[int, bool] = True, 50 resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, 51 ) -> None: 52 r""" 53 **Args:** 54 - **root** (`str` | `list[str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `list[str]`, it should be a list of strings, each string is the root directory for each task. 55 - **num_tasks** (`int`): the maximum number of tasks supported by the MTL dataset. 56 - **sampling_strategy** (`str`): the sampling strategy that construct training batch from each task's dataset; one of: 57 - 'mixed': mixed sampling strategy, which samples from all tasks' datasets. 58 - **batch_size** (`int`): The batch size in train, val, test dataloader. 59 - **num_workers** (`int`): the number of workers for dataloaders. 60 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization and so on are not included. 61 If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied. 62 - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. 63 If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. 
If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied. 64 - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. 65 If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks. 66 - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. 67 If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied. 68 """ 69 super().__init__() 70 71 self.root: dict[int, str] = ( 72 OmegaConf.to_container(root) 73 if isinstance(root, DictConfig) 74 else {t: root for t in range(1, num_tasks + 1)} 75 ) 76 r"""The dict of root directories of the original data files for each task.""" 77 self.num_tasks: int = num_tasks 78 r"""The maximum number of tasks supported by the dataset.""" 79 self.sampling_strategy: str = sampling_strategy 80 r"""The sampling strategy for constructing training batch from each task's dataset.""" 81 self.batch_size: int = batch_size 82 r"""The batch size for dataloaders.""" 83 self.num_workers: int = num_workers 84 r"""The number of workers for dataloaders.""" 85 86 self.custom_transforms: dict[int, Callable | transforms.Compose | None] = ( 87 OmegaConf.to_container(custom_transforms) 88 if isinstance(custom_transforms, dict) 89 else {t: custom_transforms for t in range(1, num_tasks + 1)} 90 ) 91 r"""The dict of custom transforms for each task.""" 92 self.repeat_channels: dict[int, int | None] = ( 93 OmegaConf.to_container(repeat_channels) 94 if isinstance(repeat_channels, dict) 95 else {t: repeat_channels for t in range(1, num_tasks + 1)} 96 ) 97 r"""The dict of number of channels to repeat for each task.""" 98 
self.to_tensor: dict[int, bool] = ( 99 OmegaConf.to_container(to_tensor) 100 if isinstance(to_tensor, dict) 101 else {t: to_tensor for t in range(1, num_tasks + 1)} 102 ) 103 r"""The dict of to_tensor flag for each task. """ 104 self.resize: dict[int, tuple[int, int] | None] = ( 105 {t: tuple(rs) if rs else None for t, rs in resize.items()} 106 if isinstance(resize, DictConfig) 107 else { 108 t: (tuple(resize) if resize else None) for t in range(1, num_tasks + 1) 109 } 110 ) 111 r"""The dict of sizes to resize to for each task.""" 112 113 # dataset containers 114 self.dataset_train: dict[int, Any] = {} 115 r"""The dictionary to store training dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects. 116 117 Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use `TaskLabelledDataset` wrapper if necessary.""" 118 self.dataset_val: dict[int, Any] = {} 119 r"""The dictionary to store validation dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects. 120 121 Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use `TaskLabelledDataset` wrapper if necessary.""" 122 self.dataset_test: dict[int, Any] = {} 123 r"""The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects. 124 125 Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use `TaskLabelledDataset` wrapper if necessary.""" 126 127 self.mean: dict[int, float] = {} 128 r"""Tthe list of mean values for normalization for all tasks. 
Used when constructing the transforms.""" 129 self.std: dict[int, float] = {} 130 r"""The list of standard deviation values for normalization for all tasks. Used when constructing the transforms.""" 131 132 # task ID controls 133 self.train_tasks: list[int] 134 r""""The list of task IDs to be trained. It should be a list of integers, each integer is the task ID.""" 135 self.eval_tasks: list[int] 136 r"""The list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID.""" 137 138 MTLDataset.sanity_check(self)
Args:
- root (
str|list[str]): the root directory where the original data files for constructing the MTL dataset physically live. Iflist[str], it should be a list of strings, each string is the root directory for each task. - num_tasks (
int): the maximum number of tasks supported by the MTL dataset. - sampling_strategy (
str): the sampling strategy that construct training batch from each task's dataset; one of:- 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
- batch_size (
int): The batch size in train, val, test dataloader. - num_workers (
int): the number of workers for dataloaders. - custom_transforms (
transformortransforms.ComposeorNoneor dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform.ToTensor(), normalization and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it isNone, no custom transforms are applied. - repeat_channels (
int|None| dict of them): the number of channels to repeat for each task. Default isNone, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is anint, it is the same number of channels to repeat for all tasks. If it isNone, no repeat is applied. - to_tensor (
bool|dict[int, bool]): whether to include theToTensor()transform. Default isTrue. If it is a dict, the keys are task IDs and the values are whether to include theToTensor()transform for each task. If it is a single boolean value, it is applied to all tasks. - resize (
tuple[int, int]|Noneor dict of them): the size to resize the images to. Default isNone, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it isNone, no resize is applied.
The sampling strategy for constructing training batch from each task's dataset.
The dict of custom transforms for each task.
The dictionary to store training dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.
Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use TaskLabelledDataset wrapper if necessary.
The dictionary to store validation dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.
Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use TaskLabelledDataset wrapper if necessary.
The dictionary to store test dataset object of each task. Keys are task IDs and values are the dataset objects. Can be PyTorch Dataset objects or any other dataset objects.
Note that they must be task labelled, i.e., the elements in the dataset objects must be tuples of (input, target, task_id). Use TaskLabelledDataset wrapper if necessary.
Tthe list of mean values for normalization for all tasks. Used when constructing the transforms.
The list of standard deviation values for normalization for all tasks. Used when constructing the transforms.
"The list of task IDs to be trained. It should be a list of integers, each integer is the task ID.
The list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID.
140 def sanity_check(self) -> None: 141 r"""Sanity check.""" 142 for attr in [ 143 "root", 144 "custom_transforms", 145 "repeat_channels", 146 "to_tensor", 147 "resize", 148 ]: 149 value = getattr(self, attr) 150 expected_keys = set(range(1, self.num_tasks + 1)) 151 if set(value.keys()) != expected_keys: 152 raise ValueError( 153 f"{attr} dict keys must be consecutive integers from 1 to num_tasks." 154 )
Sanity check.
156 @abstractmethod 157 def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]: 158 r"""Get the mapping of classes of task `task_id` to fit multi-task learning. It must be implemented by subclasses. 159 160 **Args:** 161 - **task_id** (`int`): The task ID to query class map. 162 163 **Returns:** 164 - **class_map**(`dict[str | int, int]`): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be continuous integers from 0 to the number of classes. 165 """
Get the mapping of classes of task task_id to fit multi-task learning. It must be implemented by subclasses.
Args:
- task_id (
int): The task ID to query class map.
Returns:
- class_map(
dict[str | int, int]): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be continuous integers from 0 to the number of classes.
167 @abstractmethod 168 def prepare_data(self) -> None: 169 r"""Use this to download and prepare data. It must be implemented by subclasses, as required by `LightningDatamodule`."""
Use this to download and prepare data. It must be implemented by subclasses, as required by LightningDatamodule.
171 def setup(self, stage: str) -> None: 172 r"""Set up the dataset for different stages. 173 174 **Args:** 175 - **stage** (`str`): the stage of the experiment; one of: 176 - 'fit': training and validation dataset should be assigned to `self.dataset_train` and `self.dataset_val`. 177 - 'test': test dataset should be assigned to `self.dataset_test`. 178 """ 179 if stage == "fit": 180 # these two stages must be done together because a sanity check for validation is conducted before training 181 pylogger.debug("Construct train and validation dataset ...") 182 183 for task_id in self.train_tasks: 184 185 self.dataset_train[task_id], self.dataset_val[task_id] = ( 186 self.train_and_val_dataset(task_id) 187 ) 188 189 pylogger.info( 190 "Train and validation dataset for task %d are ready.", task_id 191 ) 192 pylogger.info( 193 "Train dataset for task %d size: %d", 194 task_id, 195 len(self.dataset_train[task_id]), 196 ) 197 pylogger.info( 198 "Validation dataset for task %d size: %d", 199 task_id, 200 len(self.dataset_val[task_id]), 201 ) 202 203 elif stage == "test": 204 205 pylogger.debug("Construct test dataset ...") 206 207 for task_id in self.eval_tasks: 208 209 self.dataset_test[task_id] = self.test_dataset(task_id) 210 211 pylogger.info("Test dataset for task %d are ready.", task_id) 212 pylogger.info( 213 "Test dataset for task %d size: %d", 214 task_id, 215 len(self.dataset_test[task_id]), 216 )
Set up the dataset for different stages.
Args:
- stage (
str): the stage of the experiment; one of:- 'fit': training and validation dataset should be assigned to
self.dataset_trainandself.dataset_val. - 'test': test dataset should be assigned to
self.dataset_test.
- 'fit': training and validation dataset should be assigned to
218 def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None: 219 r"""Set up tasks for the multi-task learning experiment. 220 221 **Args:** 222 - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the train/val dataloader. 223 - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the test dataloader. 224 """ 225 self.train_tasks = train_tasks 226 self.eval_tasks = eval_tasks
Set up tasks for the multi-task learning experiment.

**Args:**
- **train_tasks** (`list[int]`): the list of task IDs to be trained. Used when constructing the train/val dataloaders.
- **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. Used when constructing the test dataloader.
228 def setup_tasks_eval(self, eval_tasks: list[int]) -> None: 229 r"""Set up evaluation tasks for the multi-task learning evaluation. 230 231 **Args:** 232 - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.""" 233 self.eval_tasks = eval_tasks
Set up evaluation tasks for the multi-task learning evaluation.

**Args:**
- **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.
235 def train_and_val_transforms(self, task_id: int) -> transforms.Compose: 236 r"""Transforms for train and validation datasets of task `task_id`, incorporating the custom transforms with basic transforms like `normalization` and `ToTensor()`. It can be used in subclasses when constructing the dataset. 237 238 **Args:** 239 - **task_id** (`int`): the task ID of training and validation dataset to get the transforms for. 240 241 **Returns:** 242 - **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms. 243 """ 244 repeat_channels_transform = ( 245 transforms.Grayscale(num_output_channels=self.repeat_channels[task_id]) 246 if self.repeat_channels[task_id] is not None 247 else None 248 ) 249 to_tensor_transform = transforms.ToTensor() if self.to_tensor[task_id] else None 250 resize_transform = ( 251 transforms.Resize(self.resize[task_id]) 252 if self.resize[task_id] is not None 253 else None 254 ) 255 normalization_transform = transforms.Normalize( 256 self.mean[task_id], self.std[task_id] 257 ) 258 259 return transforms.Compose( 260 list( 261 filter( 262 None, 263 [ 264 repeat_channels_transform, 265 to_tensor_transform, 266 resize_transform, 267 self.custom_transforms[task_id], 268 normalization_transform, 269 ], 270 ) 271 ) 272 ) # the order of transforms matters
Transforms for the train and validation datasets of task `task_id`, incorporating the custom transforms with basic transforms like normalization and `ToTensor()`. It can be used in subclasses when constructing the dataset.

**Args:**
- **task_id** (`int`): the task ID of the training and validation dataset to get the transforms for.

**Returns:**
- **train_and_val_transforms** (`transforms.Compose`): the composed train/val transforms.
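The method above builds the pipeline by filtering out `None` entries, so the fixed order (repeat channels → to tensor → resize → custom → normalize) is preserved while absent options simply drop out. A minimal sketch of that optional-composition pattern, using plain callables in place of torchvision transforms (`compose_optional` and the stand-in steps are illustrative names, not clarena APIs):

```python
# Sketch of the filter(None, ...) composition pattern used by
# train_and_val_transforms, with plain callables standing in for
# torchvision transforms.
from typing import Callable, Optional


def compose_optional(*steps: Optional[Callable]) -> Callable:
    """Compose the non-None steps in order, like
    transforms.Compose(list(filter(None, [...])))."""
    active = [s for s in steps if s is not None]

    def pipeline(x):
        for step in active:
            x = step(x)
        return x

    return pipeline


# Stand-ins for the optional steps:
resize = lambda x: x * 2        # pretend "resize"
custom = None                   # no custom transform configured
normalize = lambda x: x - 1     # pretend "normalization"

pipeline = compose_optional(resize, custom, normalize)
```

With `custom` set to `None`, the pipeline silently skips that slot, which is exactly what `filter(None, ...)` achieves in the real method.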
274 def test_transforms(self, task_id: int) -> transforms.Compose: 275 r"""Transforms for test dataset of task `task_id`. Only basic transforms like `normalization` and `ToTensor()` are included. It can be used in subclasses when constructing the dataset. 276 277 **Args:** 278 - **task_id** (`int`): the task ID of test dataset to get the transforms for. 279 280 **Returns:** 281 - **test_transforms** (`transforms.Compose`): the composed test transforms. 282 """ 283 284 repeat_channels_transform = ( 285 transforms.Grayscale(num_output_channels=self.repeat_channels[task_id]) 286 if self.repeat_channels[task_id] is not None 287 else None 288 ) 289 to_tensor_transform = transforms.ToTensor() if self.to_tensor[task_id] else None 290 resize_transform = ( 291 transforms.Resize(self.resize[task_id]) 292 if self.resize[task_id] is not None 293 else None 294 ) 295 normalization_transform = transforms.Normalize( 296 self.mean[task_id], self.std[task_id] 297 ) 298 299 return transforms.Compose( 300 list( 301 filter( 302 None, 303 [ 304 repeat_channels_transform, 305 to_tensor_transform, 306 resize_transform, 307 normalization_transform, 308 ], 309 ) 310 ) 311 ) # the order of transforms matters. No custom transforms for test
Transforms for the test dataset of task `task_id`. Only basic transforms like normalization and `ToTensor()` are included. It can be used in subclasses when constructing the dataset.

**Args:**
- **task_id** (`int`): the task ID of the test dataset to get the transforms for.

**Returns:**
- **test_transforms** (`transforms.Compose`): the composed test transforms.
313 def target_transform(self, task_id: int) -> Callable: 314 r"""Target transform for task `task_id`, which maps the original class labels to the integer class labels for multi-task learning. It can be used in subclasses when constructing the dataset. 315 316 **Args:** 317 - **task_id** (`int`): the task ID of dataset to get the target transform for. 318 319 **Returns:** 320 - **target_transform** (`Callable`): the target transform function. 321 """ 322 323 class_map = self.get_mtl_class_map(task_id) 324 325 target_transform = ClassMapping(class_map=class_map) 326 return target_transform
Target transform for task `task_id`, which maps the original class labels to the integer class labels for multi-task learning. It can be used in subclasses when constructing the dataset.

**Args:**
- **task_id** (`int`): the task ID of the dataset to get the target transform for.

**Returns:**
- **target_transform** (`Callable`): the target transform function.
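The target transform is a callable that looks each original label up in the class map. A simplified stand-in for the `ClassMapping` callable used above (this sketch is not the actual clarena implementation):

```python
# Minimal stand-in for a ClassMapping-style target transform: maps
# original class labels to contiguous integer labels via a class map.
class SimpleClassMapping:
    def __init__(self, class_map: dict) -> None:
        self.class_map = class_map

    def __call__(self, target):
        # Look the original label up; returns the MTL integer label.
        return self.class_map[target]


# Toy class map: original string labels -> contiguous integers.
target_transform = SimpleClassMapping({"cat": 0, "dog": 1, "bird": 2})
```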
328 @abstractmethod 329 def train_and_val_dataset(self, task_id: int) -> tuple[Any, Any]: 330 r"""Get the training and validation dataset of task `task_id`. It must be implemented by subclasses. 331 332 **Args:** 333 - **task_id** (`int`): the task ID to get the training and validation dataset for. 334 335 **Returns:** 336 - **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation dataset of task `task_id`. 337 """
Get the training and validation dataset of task `task_id`. It must be implemented by subclasses.

**Args:**
- **task_id** (`int`): the task ID to get the training and validation dataset for.

**Returns:**
- **train_and_val_dataset** (`tuple[Any, Any]`): the train and validation dataset of task `task_id`.
339 @abstractmethod 340 def test_dataset(self, task_id: int) -> Any: 341 """Get the test dataset of task `task_id`. It must be implemented by subclasses. 342 343 **Args:** 344 - **task_id** (`int`): the task ID to get the test dataset for. 345 346 **Returns:** 347 - **test_dataset** (`Any`): the test dataset of task `task_id`. 348 """
Get the test dataset of task `task_id`. It must be implemented by subclasses.

**Args:**
- **task_id** (`int`): the task ID to get the test dataset for.

**Returns:**
- **test_dataset** (`Any`): the test dataset of task `task_id`.
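Together these two abstract methods define the subclass contract: `train_and_val_dataset` returns a `(train, val)` pair and `test_dataset` returns a single dataset. A toy sketch of that contract, using a stripped-down stand-in base class (plain lists instead of torch Datasets; the names here are illustrative only):

```python
# Sketch of the abstract-method contract a subclass must satisfy.
from abc import ABC, abstractmethod
from typing import Any


class DatasetBase(ABC):
    """Stripped-down stand-in for the MTLDataset abstract interface."""

    @abstractmethod
    def train_and_val_dataset(self, task_id: int) -> tuple[Any, Any]: ...

    @abstractmethod
    def test_dataset(self, task_id: int) -> Any: ...


class ToyMTLDataset(DatasetBase):
    """Task t holds the integers [0, 10*t); 80/20 train/val split."""

    def train_and_val_dataset(self, task_id: int) -> tuple[list, list]:
        data = list(range(10 * task_id))
        split = int(0.8 * len(data))
        return data[:split], data[split:]

    def test_dataset(self, task_id: int) -> list:
        return list(range(10 * task_id))
```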
350 def train_dataloader(self) -> DataLoader: 351 r"""DataLoader generator for stage train. It is automatically called before training. 352 353 **Returns:** 354 - **train_dataloader** (`DataLoader`): the train DataLoader of task. 355 """ 356 357 pylogger.debug( 358 "Construct train dataloader ... sampling_strategy method: %s", 359 self.sampling_strategy, 360 ) 361 362 if self.sampling_strategy == "mixed": 363 # mixed sampling strategy, which samples from all tasks' datasets 364 365 concatenated_dataset = ConcatDataset( 366 [self.dataset_train[task_id] for task_id in self.train_tasks] 367 ) 368 369 return DataLoader( 370 dataset=concatenated_dataset, 371 batch_size=self.batch_size, 372 shuffle=True, # shuffle train batch to prevent overfitting 373 num_workers=self.num_workers, 374 drop_last=True, # to avoid batchnorm error (when batch_size is 1) 375 )
DataLoader generator for the train stage. It is automatically called before training.

**Returns:**
- **train_dataloader** (`DataLoader`): the train DataLoader.
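Under the 'mixed' strategy, the per-task train datasets are concatenated into a single pool before batching, so each shuffled batch can mix samples from all tasks. A pure-Python sketch of that concatenation step (plain lists stand in for torch Datasets; `mixed_pool` is an illustrative name for what `ConcatDataset` does in the real method):

```python
# Sketch of the 'mixed' sampling strategy: concatenate per-task train
# datasets into one pool, then shuffle (mirroring shuffle=True in the
# DataLoader).
import random


def mixed_pool(dataset_train: dict[int, list], train_tasks: list[int]) -> list:
    """Analog of ConcatDataset([dataset_train[t] for t in train_tasks])."""
    pool = []
    for task_id in train_tasks:
        pool.extend(dataset_train[task_id])
    return pool


dataset_train = {1: ["a1", "a2"], 2: ["b1", "b2", "b3"]}
pool = mixed_pool(dataset_train, train_tasks=[1, 2])

rng = random.Random(0)
rng.shuffle(pool)  # batches drawn from the shuffled pool mix tasks
```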
 377 def val_dataloader(self) -> dict[int, DataLoader]: 378 r"""DataLoader generator for the validation stage. It is automatically called before validation. 379 380 **Returns:** 381 - **val_dataloader** (`dict[int, DataLoader]`): the validation DataLoader. 382 """ 383 384 pylogger.debug("Construct validation dataloader...") 385 386 return { 387 task_id: DataLoader( 388 dataset=dataset_val_t, 389 batch_size=self.batch_size, 390 shuffle=False, # don't have to shuffle val or test batch 391 num_workers=self.num_workers, 392 ) 393 for task_id, dataset_val_t in self.dataset_val.items() 394 }
DataLoader generator for the validation stage. It is automatically called before validation.

**Returns:**
- **val_dataloader** (`dict[int, DataLoader]`): the validation DataLoaders, keyed by task ID.
396 def test_dataloader(self) -> dict[int, DataLoader]: 397 r"""DataLoader generator for stage test. It is automatically called before testing. 398 399 **Returns:** 400 - **test_dataloader** (`dict[int, DataLoader]`): the test DataLoader. 401 """ 402 403 pylogger.debug("Construct test dataloader...") 404 405 return { 406 task_id: DataLoader( 407 dataset=dataset_test_t, 408 batch_size=self.batch_size, 409 shuffle=False, # don't have to shuffle val or test batch 410 num_workers=self.num_workers, 411 ) 412 for task_id, dataset_test_t in self.dataset_test.items() 413 }
DataLoader generator for the test stage. It is automatically called before testing.

**Returns:**
- **test_dataloader** (`dict[int, DataLoader]`): the test DataLoaders, keyed by task ID.
416class MTLCombinedDataset(MTLDataset): 417 r"""The base class of multi-task learning datasets constructed as combinations of several single-task datasets (one dataset per task).""" 418 419 def __init__( 420 self, 421 datasets: dict[int, str], 422 root: str | dict[int, str], 423 sampling_strategy: str = "mixed", 424 batch_size: int = 1, 425 num_workers: int = 0, 426 custom_transforms: ( 427 Callable 428 | transforms.Compose 429 | None 430 | dict[int, Callable | transforms.Compose | None] 431 ) = None, 432 repeat_channels: int | None | dict[int, int | None] = None, 433 to_tensor: bool | dict[int, bool] = True, 434 resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, 435 ) -> None: 436 r""" 437 **Args:** 438 - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task. 439 - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `dict[int, str]`, it should be a dict of task IDs and their corresponding root directories. 440 - **sampling_strategy** (`str`): the sampling strategy that construct training batch from each task's dataset; one of: 441 - 'mixed': mixed sampling strategy, which samples from all tasks' datasets. 442 - **batch_size** (`int`): The batch size in train, val, test dataloader. 443 - **num_workers** (`int`): the number of workers for dataloaders. 444 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and so on are not included. 445 If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. 
If it is `None`, no custom transforms are applied. 446 - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. 447 If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied. 448 - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. 449 If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks. 450 - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied. 451 """ 452 super().__init__( 453 root=root, 454 num_tasks=len( 455 datasets 456 ), # num_tasks is not explicitly provided, but derived from the datasets length 457 sampling_strategy=sampling_strategy, 458 batch_size=batch_size, 459 num_workers=num_workers, 460 custom_transforms=custom_transforms, 461 repeat_channels=repeat_channels, 462 to_tensor=to_tensor, 463 resize=resize, 464 ) 465 466 self.original_dataset_python_classes: dict[int, Dataset] = { 467 t: str_to_class(dataset_class_path) 468 for t, dataset_class_path in datasets.items() 469 } 470 r"""The dict of dataset classes for each task.""" 471 472 def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]: 473 r"""Get the mapping of classes of task `task_id` to fit multi-task learning. 474 475 **Args:** 476 - **task_id** (`int`): the task ID to query the class map. 
477 478 **Returns:** 479 - **class_map** (`dict[str | int, int]`): the class map of the task. Keys are the original class label and values are the integer class labels for multi-task learning. For multi-task learning, the mapped class labels of a task should be continuous integers from 0 to the number of classes. 480 """ 481 original_dataset_python_class_t = self.original_dataset_python_classes[task_id] 482 original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[ 483 original_dataset_python_class_t 484 ] 485 num_classes_t = original_dataset_constants_t.NUM_CLASSES 486 class_map_t = original_dataset_constants_t.CLASS_MAP 487 488 return {class_map_t[i]: i for i in range(num_classes_t)} 489 490 def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None: 491 r"""Set up tasks for the multi-task learning experiment. 492 493 **Args:** 494 - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader. 495 - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader. 496 """ 497 super().setup_tasks_expr(train_tasks=train_tasks, eval_tasks=eval_tasks) 498 499 for task_id in train_tasks + eval_tasks: 500 original_dataset_python_class_t = self.original_dataset_python_classes[ 501 task_id 502 ] 503 original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[ 504 original_dataset_python_class_t 505 ] 506 self.mean[task_id] = original_dataset_constants_t.MEAN 507 self.std[task_id] = original_dataset_constants_t.STD 508 509 def setup_tasks_eval(self, eval_tasks: list[int]) -> None: 510 r"""Set up evaluation tasks for the multi-task learning evaluation. 511 512 **Args:** 513 - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. 514 """ 515 super().setup_tasks_eval(eval_tasks=eval_tasks)
The base class of multi-task learning datasets constructed as combinations of several single-task datasets (one dataset per task).
419 def __init__( 420 self, 421 datasets: dict[int, str], 422 root: str | dict[int, str], 423 sampling_strategy: str = "mixed", 424 batch_size: int = 1, 425 num_workers: int = 0, 426 custom_transforms: ( 427 Callable 428 | transforms.Compose 429 | None 430 | dict[int, Callable | transforms.Compose | None] 431 ) = None, 432 repeat_channels: int | None | dict[int, int | None] = None, 433 to_tensor: bool | dict[int, bool] = True, 434 resize: tuple[int, int] | None | dict[int, tuple[int, int] | None] = None, 435 ) -> None: 436 r""" 437 **Args:** 438 - **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task. 439 - **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `dict[int, str]`, it should be a dict of task IDs and their corresponding root directories. 440 - **sampling_strategy** (`str`): the sampling strategy that construct training batch from each task's dataset; one of: 441 - 'mixed': mixed sampling strategy, which samples from all tasks' datasets. 442 - **batch_size** (`int`): The batch size in train, val, test dataloader. 443 - **num_workers** (`int`): the number of workers for dataloaders. 444 - **custom_transforms** (`transform` or `transforms.Compose` or `None` or dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and so on are not included. 445 If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied. 446 - **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. 
447 If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, it is the same number of channels to repeat for all tasks. If it is `None`, no repeat is applied. 448 - **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. 449 If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks. 450 - **resize** (`tuple[int, int]` | `None` or dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied. 451 """ 452 super().__init__( 453 root=root, 454 num_tasks=len( 455 datasets 456 ), # num_tasks is not explicitly provided, but derived from the datasets length 457 sampling_strategy=sampling_strategy, 458 batch_size=batch_size, 459 num_workers=num_workers, 460 custom_transforms=custom_transforms, 461 repeat_channels=repeat_channels, 462 to_tensor=to_tensor, 463 resize=resize, 464 ) 465 466 self.original_dataset_python_classes: dict[int, Dataset] = { 467 t: str_to_class(dataset_class_path) 468 for t, dataset_class_path in datasets.items() 469 } 470 r"""The dict of dataset classes for each task."""
**Args:**
- **datasets** (`dict[int, str]`): the dict of dataset class paths for each task. The keys are task IDs and the values are the dataset class paths (as strings) to use for each task.
- **root** (`str` | `dict[int, str]`): the root directory where the original data files for constructing the MTL dataset physically live. If `dict[int, str]`, it should be a dict of task IDs and their corresponding root directories.
- **sampling_strategy** (`str`): the sampling strategy that constructs training batches from each task's dataset; one of:
    - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
- **batch_size** (`int`): the batch size in the train, val, and test dataloaders.
- **num_workers** (`int`): the number of workers for dataloaders.
- **custom_transforms** (`transform` | `transforms.Compose` | `None` | dict of them): the custom transforms to apply ONLY to the TRAIN dataset. Can be a single transform, composed transforms, or no transform. `ToTensor()`, normalization, and so on are not included. If it is a dict, the keys are task IDs and the values are the custom transforms for each task. If it is a single transform or composed transforms, it is applied to all tasks. If it is `None`, no custom transforms are applied.
- **repeat_channels** (`int` | `None` | dict of them): the number of channels to repeat for each task. Default is `None`, which means no repeat. If it is a dict, the keys are task IDs and the values are the number of channels to repeat for each task. If it is an `int`, the same number of channels is repeated for all tasks. If it is `None`, no repeat is applied.
- **to_tensor** (`bool` | `dict[int, bool]`): whether to include the `ToTensor()` transform. Default is `True`. If it is a dict, the keys are task IDs and the values are whether to include the `ToTensor()` transform for each task. If it is a single boolean value, it is applied to all tasks.
- **resize** (`tuple[int, int]` | `None` | dict of them): the size to resize the images to. Default is `None`, which means no resize. If it is a dict, the keys are task IDs and the values are the sizes to resize to for each task. If it is a single tuple of two integers, it is applied to all tasks. If it is `None`, no resize is applied.
The dict of dataset classes for each task.
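Each per-task option above (`custom_transforms`, `repeat_channels`, `to_tensor`, `resize`) accepts either one value applied to every task or a dict keyed by task ID. A plausible resolution step looks like this sketch (`broadcast_option` is an illustrative name, not a clarena API):

```python
# Sketch of resolving "scalar or per-task dict" options into a uniform
# per-task dict, as the constructor arguments above describe.
from typing import Any


def broadcast_option(option: Any, task_ids: list[int]) -> dict[int, Any]:
    """Expand a scalar option to a per-task dict; pass dicts through."""
    if isinstance(option, dict):
        return {t: option[t] for t in task_ids}
    return {t: option for t in task_ids}


tasks = [1, 2, 3]
to_tensor = broadcast_option(True, tasks)                      # scalar -> all tasks
resize = broadcast_option({1: (32, 32), 2: None, 3: None}, tasks)  # per-task dict
```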
472 def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]: 473 r"""Get the mapping of classes of task `task_id` to fit multi-task learning. 474 475 **Args:** 476 - **task_id** (`int`): the task ID to query the class map. 477 478 **Returns:** 479 - **class_map** (`dict[str | int, int]`): the class map of the task. Keys are the original class label and values are the integer class labels for multi-task learning. For multi-task learning, the mapped class labels of a task should be continuous integers from 0 to the number of classes. 480 """ 481 original_dataset_python_class_t = self.original_dataset_python_classes[task_id] 482 original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[ 483 original_dataset_python_class_t 484 ] 485 num_classes_t = original_dataset_constants_t.NUM_CLASSES 486 class_map_t = original_dataset_constants_t.CLASS_MAP 487 488 return {class_map_t[i]: i for i in range(num_classes_t)}
Get the mapping of classes of task `task_id` to fit multi-task learning.

**Args:**
- **task_id** (`int`): the task ID to query the class map.

**Returns:**
- **class_map** (`dict[str | int, int]`): the class map of the task. Keys are the original class labels and values are the integer class labels for multi-task learning. For multi-task learning, the mapped class labels of a task should be continuous integers from 0 to the number of classes.
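Given per-dataset constants `CLASS_MAP` (index → original label) and `NUM_CLASSES`, the method inverts that indexing with a dict comprehension, `{class_map_t[i]: i for i in range(num_classes_t)}`. A sketch with toy constants (the label names here are assumptions for illustration):

```python
# Sketch of the class-map inversion in get_mtl_class_map: turn an
# index->label table into a label->contiguous-integer map.
NUM_CLASSES = 3
CLASS_MAP = {0: "airplane", 1: "automobile", 2: "bird"}  # toy constants

mtl_class_map = {CLASS_MAP[i]: i for i in range(NUM_CLASSES)}
```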
490 def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None: 491 r"""Set up tasks for the multi-task learning experiment. 492 493 **Args:** 494 - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader. 495 - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader. 496 """ 497 super().setup_tasks_expr(train_tasks=train_tasks, eval_tasks=eval_tasks) 498 499 for task_id in train_tasks + eval_tasks: 500 original_dataset_python_class_t = self.original_dataset_python_classes[ 501 task_id 502 ] 503 original_dataset_constants_t = DATASET_CONSTANTS_MAPPING[ 504 original_dataset_python_class_t 505 ] 506 self.mean[task_id] = original_dataset_constants_t.MEAN 507 self.std[task_id] = original_dataset_constants_t.STD
Set up tasks for the multi-task learning experiment.

**Args:**
- **train_tasks** (`list[int]`): the list of task IDs to be trained. Used when constructing the dataloaders.
- **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. Used when constructing the dataloaders.
509 def setup_tasks_eval(self, eval_tasks: list[int]) -> None: 510 r"""Set up evaluation tasks for the multi-task learning evaluation. 511 512 **Args:** 513 - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. 514 """ 515 super().setup_tasks_eval(eval_tasks=eval_tasks)
Set up evaluation tasks for the multi-task learning evaluation.

**Args:**
- **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.
518class MTLDatasetFromCL(MTLDataset): 519 r"""Multi-task learning datasets constructed from the CL datasets. 520 521 This is usually for constructing the reference joint learning experiment for continual learning. 522 """ 523 524 def __init__( 525 self, 526 cl_dataset: CLDataset, 527 sampling_strategy: str = "mixed", 528 batch_size: int = 1, 529 num_workers: int = 0, 530 ) -> None: 531 r"""Initialize the `MTLDatasetFromCL` object. 532 533 **Args:** 534 - **cl_dataset** (`CLDataset`): the CL dataset object to be used for constructing the MTL dataset. 535 - **sampling_strategy** (`str`): the sampling strategy that construct training batch from each task's dataset; one of: 536 - 'mixed': mixed sampling strategy, which samples from all tasks' datasets. 537 - **batch_size** (`int`): The batch size in train, val, test dataloader. 538 - **num_workers** (`int`): the number of workers for dataloaders. 539 """ 540 541 self.cl_dataset: CLDataset = cl_dataset 542 r"""The CL dataset for constructing the MTL dataset.""" 543 544 super().__init__( 545 root=None, 546 num_tasks=cl_dataset.num_tasks, 547 sampling_strategy=sampling_strategy, 548 batch_size=batch_size, 549 num_workers=num_workers, 550 custom_transforms=None, # already handled in the CL dataset 551 repeat_channels=None, 552 to_tensor=None, 553 resize=None, 554 ) 555 556 def prepare_data(self) -> None: 557 r"""Download and prepare data.""" 558 self.cl_dataset.prepare_data() # prepare the CL dataset 559 560 def setup(self, stage: str) -> None: 561 r"""Set up the dataset for different stages. 562 563 **Args:** 564 - **stage** (`str`): the stage of the experiment; one of: 565 - 'fit': training and validation dataset should be assigned to `self.dataset_train` and `self.dataset_val`. 566 - 'test': test dataset should be assigned to `self.dataset_test`. 
567 """ 568 if stage == "fit": 569 pylogger.debug("Construct train and validation dataset ...") 570 571 # go through each task of continual learning to get the training dataset of each task 572 for task_id in range(1, self.num_tasks + 1): 573 self.cl_dataset.setup_task_id(task_id) 574 self.cl_dataset.setup(stage) 575 576 # label the training dataset with the task ID 577 task_labelled_dataset_train_t = TaskLabelledDataset( 578 self.cl_dataset.dataset_train_t, task_id 579 ) 580 self.dataset_train[task_id] = task_labelled_dataset_train_t 581 582 # label the validation dataset with the task ID 583 task_labelled_dataset_val_t = TaskLabelledDataset( 584 self.cl_dataset.dataset_val_t, task_id 585 ) 586 self.dataset_val[task_id] = task_labelled_dataset_val_t 587 588 pylogger.debug( 589 "Train and validation dataset for task %d are ready.", task_id 590 ) 591 pylogger.info( 592 "Train dataset for task %d size: %d", 593 task_id, 594 len(self.dataset_train[task_id]), 595 ) 596 pylogger.info( 597 "Validation dataset for task %d size: %d", 598 task_id, 599 len(self.dataset_val[task_id]), 600 ) 601 602 elif stage == "test": 603 604 pylogger.debug("Construct test dataset ...") 605 606 for task_id in self.eval_tasks: 607 608 self.cl_dataset.setup_task_id(task_id) 609 self.cl_dataset.setup(stage) 610 611 task_labelled_dataset_test_t = TaskLabelledDataset( 612 self.cl_dataset.dataset_test[task_id], task_id 613 ) 614 615 self.dataset_test[task_id] = task_labelled_dataset_test_t 616 617 pylogger.debug("Test dataset for task %d are ready.", task_id) 618 pylogger.info( 619 "Test dataset for task %d size: %d", 620 task_id, 621 len(self.dataset_test[task_id]), 622 ) 623 624 def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]: 625 r"""Get the mapping of classes of task `task_id` to fit multi-task learning. 626 627 **Args:** 628 - **task_id** (`int`): The task ID to query class map. 629 630 **Returns:** 631 - **class_map**(`dict[str | int, int]`): the class map of the task. 
Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be continuous integers from 0 to the number of classes. 632 """ 633 return self.cl_dataset.get_cl_class_map( 634 task_id 635 ) # directly use the CL dataset's class map (from TIL setting) 636 637 def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None: 638 r"""Set up tasks for the multi-task learning experiment. 639 640 **Args:** 641 - **train_tasks** (`list[int]`): the list of task IDs to be trained. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader. 642 - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated. It should be a list of integers, each integer is the task ID. This is used when constructing the dataloader. 643 """ 644 super().setup_tasks_expr(train_tasks=train_tasks, eval_tasks=eval_tasks) 645 646 # MTL requires independent heads 647 self.cl_dataset.set_cl_paradigm(cl_paradigm="TIL") 648 for task_id in train_tasks + eval_tasks: 649 self.cl_dataset.setup_task_id(task_id) 650 651 def setup_tasks_eval(self, eval_tasks: list[int]) -> None: 652 r"""Set up evaluation tasks for the multi-task learning evaluation. 653 654 **Args:** 655 - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.""" 656 super().setup_tasks_eval(eval_tasks=eval_tasks) 657 658 # MTL requires independent heads 659 self.cl_dataset.set_cl_paradigm(cl_paradigm="TIL") 660 for task_id in eval_tasks: 661 self.cl_dataset.setup_task_id(task_id)
Multi-task learning datasets constructed from continual learning (CL) datasets.

This is usually used to construct the reference joint learning experiment for continual learning.
524 def __init__( 525 self, 526 cl_dataset: CLDataset, 527 sampling_strategy: str = "mixed", 528 batch_size: int = 1, 529 num_workers: int = 0, 530 ) -> None: 531 r"""Initialize the `MTLDatasetFromCL` object. 532 533 **Args:** 534 - **cl_dataset** (`CLDataset`): the CL dataset object to be used for constructing the MTL dataset. 535 - **sampling_strategy** (`str`): the sampling strategy that construct training batch from each task's dataset; one of: 536 - 'mixed': mixed sampling strategy, which samples from all tasks' datasets. 537 - **batch_size** (`int`): The batch size in train, val, test dataloader. 538 - **num_workers** (`int`): the number of workers for dataloaders. 539 """ 540 541 self.cl_dataset: CLDataset = cl_dataset 542 r"""The CL dataset for constructing the MTL dataset.""" 543 544 super().__init__( 545 root=None, 546 num_tasks=cl_dataset.num_tasks, 547 sampling_strategy=sampling_strategy, 548 batch_size=batch_size, 549 num_workers=num_workers, 550 custom_transforms=None, # already handled in the CL dataset 551 repeat_channels=None, 552 to_tensor=None, 553 resize=None, 554 )
Initialize the `MTLDatasetFromCL` object.

**Args:**
- **cl_dataset** (`CLDataset`): the CL dataset object to be used for constructing the MTL dataset.
- **sampling_strategy** (`str`): the sampling strategy that constructs training batches from each task's dataset; one of:
    - 'mixed': mixed sampling strategy, which samples from all tasks' datasets.
- **batch_size** (`int`): the batch size in the train, val, and test dataloaders.
- **num_workers** (`int`): the number of workers for dataloaders.
556 def prepare_data(self) -> None: 557 r"""Download and prepare data.""" 558 self.cl_dataset.prepare_data() # prepare the CL dataset
Download and prepare data.
def setup(self, stage: str) -> None:
    r"""Set up the dataset for different stages.

    **Args:**
    - **stage** (`str`): the stage of the experiment; one of:
        - 'fit': training and validation datasets should be assigned to `self.dataset_train` and `self.dataset_val`.
        - 'test': test dataset should be assigned to `self.dataset_test`.
    """
    if stage == "fit":
        pylogger.debug("Construct train and validation dataset ...")

        # go through each task of continual learning to get the training dataset of each task
        for task_id in range(1, self.num_tasks + 1):
            self.cl_dataset.setup_task_id(task_id)
            self.cl_dataset.setup(stage)

            # label the training dataset with the task ID
            task_labelled_dataset_train_t = TaskLabelledDataset(
                self.cl_dataset.dataset_train_t, task_id
            )
            self.dataset_train[task_id] = task_labelled_dataset_train_t

            # label the validation dataset with the task ID
            task_labelled_dataset_val_t = TaskLabelledDataset(
                self.cl_dataset.dataset_val_t, task_id
            )
            self.dataset_val[task_id] = task_labelled_dataset_val_t

            pylogger.debug(
                "Train and validation datasets for task %d are ready.", task_id
            )
            pylogger.info(
                "Train dataset for task %d size: %d",
                task_id,
                len(self.dataset_train[task_id]),
            )
            pylogger.info(
                "Validation dataset for task %d size: %d",
                task_id,
                len(self.dataset_val[task_id]),
            )

    elif stage == "test":

        pylogger.debug("Construct test dataset ...")

        for task_id in self.eval_tasks:

            self.cl_dataset.setup_task_id(task_id)
            self.cl_dataset.setup(stage)

            task_labelled_dataset_test_t = TaskLabelledDataset(
                self.cl_dataset.dataset_test[task_id], task_id
            )

            self.dataset_test[task_id] = task_labelled_dataset_test_t

            pylogger.debug("Test dataset for task %d is ready.", task_id)
            pylogger.info(
                "Test dataset for task %d size: %d",
                task_id,
                len(self.dataset_test[task_id]),
            )
Set up the dataset for different stages.

**Args:**
- **stage** (`str`): the stage of the experiment; one of:
    - 'fit': training and validation datasets should be assigned to `self.dataset_train` and `self.dataset_val`.
    - 'test': test dataset should be assigned to `self.dataset_test`.
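The `setup()` method above wraps each task's dataset in a `TaskLabelledDataset` so that every sample carries its task ID. The exact implementation of that class is not shown on this page; the following pure-Python sketch (with a hypothetical `TaskLabelledSketch` stand-in and toy data) illustrates the idea behind it:

```python
class TaskLabelledSketch:
    """Illustrative stand-in for `TaskLabelledDataset`: wrap a map-style
    dataset so every (input, label) sample also carries its task ID."""

    def __init__(self, dataset, task_id: int) -> None:
        self.dataset = dataset
        self.task_id = task_id

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int):
        x, y = self.dataset[idx]
        return x, y, self.task_id


# toy per-task datasets: lists of (input, label) pairs
task1 = TaskLabelledSketch([("img_a", 0), ("img_b", 1)], task_id=1)
task2 = TaskLabelledSketch([("img_c", 0)], task_id=2)

# the 'mixed' sampling strategy can then draw batches from the
# concatenation of all task-labelled datasets
mixed = [task1[i] for i in range(len(task1))] + [task2[i] for i in range(len(task2))]
```

Carrying the task ID with each sample is what lets a mixed-task dataloader route every sample to the correct task-specific output head at training time.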
def get_mtl_class_map(self, task_id: int) -> dict[str | int, int]:
    r"""Get the mapping of classes of task `task_id` to fit multi-task learning.

    **Args:**
    - **task_id** (`int`): the task ID to query the class map for.

    **Returns:**
    - **class_map** (`dict[str | int, int]`): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be continuous integers from 0 to the number of classes minus 1.
    """
    return self.cl_dataset.get_cl_class_map(
        task_id
    )  # directly use the CL dataset's class map (from the TIL setting)
Get the mapping of classes of task `task_id` to fit multi-task learning.

**Args:**
- **task_id** (`int`): the task ID to query the class map for.

**Returns:**
- **class_map** (`dict[str | int, int]`): the class map of the task. Keys are original class labels and values are integer class labels for multi-task learning. The mapped class labels of each task should be continuous integers from 0 to the number of classes minus 1.
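The class-map contract above is simple: original labels (strings or ints) become continuous integers starting at 0. A minimal sketch, with hypothetical labels (the real map is produced by the wrapped CL dataset, not built by hand):

```python
def make_class_map(original_labels):
    """Assign continuous integer labels 0..n-1 to original class labels."""
    return {label: i for i, label in enumerate(original_labels)}


class_map = make_class_map(["cat", "dog", "bird"])
# class_map == {"cat": 0, "dog": 1, "bird": 2}

# remap a batch of original labels to MTL integer labels
mtl_labels = [class_map[label] for label in ["dog", "cat", "bird"]]
```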
def setup_tasks_expr(self, train_tasks: list[int], eval_tasks: list[int]) -> None:
    r"""Set up tasks for the multi-task learning experiment.

    **Args:**
    - **train_tasks** (`list[int]`): the list of task IDs to be trained, used when constructing the dataloaders.
    - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated, used when constructing the dataloaders.
    """
    super().setup_tasks_expr(train_tasks=train_tasks, eval_tasks=eval_tasks)

    # MTL requires independent heads
    self.cl_dataset.set_cl_paradigm(cl_paradigm="TIL")
    for task_id in train_tasks + eval_tasks:
        self.cl_dataset.setup_task_id(task_id)
Set up tasks for the multi-task learning experiment.

**Args:**
- **train_tasks** (`list[int]`): the list of task IDs to be trained, used when constructing the dataloaders.
- **eval_tasks** (`list[int]`): the list of task IDs to be evaluated, used when constructing the dataloaders.
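Note that the source loops over `train_tasks + eval_tasks`, i.e. the plain list concatenation, so a task appearing in both lists has `setup_task_id()` called on it once per occurrence. A small sketch of that behavior (the task IDs are illustrative):

```python
train_tasks = [1, 2, 3]
eval_tasks = [2, 3, 4]

# the loop in `setup_tasks_expr` visits each occurrence, duplicates included
touched = train_tasks + eval_tasks

# the distinct tasks whose data must be set up at least once
unique_touched = sorted(set(touched))
```

Since `setup_task_id()` is presumably idempotent per task, the duplicates are harmless, but callers who care can deduplicate as shown.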
def setup_tasks_eval(self, eval_tasks: list[int]) -> None:
    r"""Set up evaluation tasks for the multi-task learning evaluation.

    **Args:**
    - **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.
    """
    super().setup_tasks_eval(eval_tasks=eval_tasks)

    # MTL requires independent heads
    self.cl_dataset.set_cl_paradigm(cl_paradigm="TIL")
    for task_id in eval_tasks:
        self.cl_dataset.setup_task_id(task_id)
Set up evaluation tasks for the multi-task learning evaluation.

**Args:**
- **eval_tasks** (`list[int]`): the list of task IDs to be evaluated.