clarena.pipelines.mtl_expr
The submodule in `pipelines` for the multi-task learning experiment.
1r""" 2The submodule in `pipelines` for multi-task learning experiment. 3""" 4 5__all__ = ["MTLExperiment"] 6 7import logging 8from typing import Any 9 10import hydra 11import lightning as L 12from lightning import Callback, Trainer 13from lightning.pytorch.loggers import Logger 14from omegaconf import DictConfig 15from torch.optim import Optimizer 16from torch.optim.lr_scheduler import LRScheduler 17 18from clarena.backbones import Backbone, CLBackbone 19from clarena.heads import HeadsMTL 20from clarena.mtl_algorithms import MTLAlgorithm 21from clarena.mtl_datasets import MTLDataset 22from clarena.utils.cfg import select_hyperparameters_from_config 23 24# always get logger for built-in logging in each module 25pylogger = logging.getLogger(__name__) 26 27 28class MTLExperiment: 29 r"""The base class for multi-task learning experiment.""" 30 31 def __init__(self, cfg: DictConfig) -> None: 32 r""" 33 **Args:** 34 - **cfg** (`DictConfig`): the complete config dict for the multi-task learning experiment. 
35 """ 36 self.cfg: DictConfig = cfg 37 r"""The complete config dict.""" 38 39 MTLExperiment.sanity_check(self) 40 41 # required config fields 42 self.train_tasks: list[int] = ( 43 cfg.train_tasks 44 if isinstance(cfg.train_tasks, list) 45 else list(range(1, cfg.train_tasks + 1)) 46 ) 47 r"""The list of tasks to train.""" 48 self.eval_tasks: list[int] = ( 49 cfg.eval_tasks 50 if isinstance(cfg.eval_tasks, list) 51 else list(range(1, cfg.eval_tasks + 1)) 52 ) 53 r"""The list of tasks to evaluate.""" 54 self.global_seed: int = cfg.global_seed 55 r"""The global seed for the entire experiment.""" 56 self.output_dir: str = cfg.output_dir 57 r"""The folder for storing the experiment results.""" 58 59 # components 60 self.mtl_dataset: MTLDataset 61 r"""MTL dataset object.""" 62 self.backbone: CLBackbone 63 r"""Backbone network object.""" 64 self.heads: HeadsMTL 65 r"""MTL output heads object.""" 66 self.model: MTLAlgorithm 67 r"""MTL model object.""" 68 self.optimizer: Optimizer 69 r"""Optimizer object.""" 70 self.lr_scheduler: LRScheduler | None 71 r"""Learning rate scheduler object.""" 72 self.lightning_loggers: list[Logger] 73 r"""The list of initialized lightning loggers objects.""" 74 self.callbacks: list[Callback] 75 r"""The list of initialized callbacks objects.""" 76 self.trainer: Trainer 77 r"""Trainer object.""" 78 79 def sanity_check(self) -> None: 80 r"""Sanity check for config.""" 81 82 # check required config fields 83 required_config_fields = [ 84 "pipeline", 85 "expr_name", 86 "train_tasks", 87 "eval_tasks", 88 "global_seed", 89 "mtl_dataset", 90 "mtl_algorithm", 91 "backbone", 92 "optimizer", 93 "lr_scheduler", 94 "trainer", 95 "metrics", 96 "lightning_loggers", 97 "callbacks", 98 "output_dir", 99 # "hydra" is excluded as it doesn't appear 100 "misc", 101 ] 102 for field in required_config_fields: 103 if not self.cfg.get(field): 104 raise KeyError( 105 f"Field `{field}` is required in the experiment index config." 
106 ) 107 108 # get dataset number of tasks 109 if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL": 110 cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset") 111 if cl_dataset_cfg.get("num_tasks"): 112 num_tasks = cl_dataset_cfg.get("num_tasks") 113 elif cl_dataset_cfg.get("class_split"): 114 num_tasks = len(cl_dataset_cfg.class_split) 115 elif cl_dataset_cfg.get("datasets"): 116 num_tasks = len(cl_dataset_cfg.datasets) 117 else: 118 raise KeyError( 119 "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config." 120 ) 121 else: 122 if self.cfg.mtl_dataset.get("num_tasks"): 123 num_tasks = self.cfg.mtl_dataset.num_tasks 124 else: 125 raise KeyError( 126 "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config." 127 ) 128 129 # check train_tasks 130 train_tasks = self.cfg.train_tasks 131 if isinstance(train_tasks, list): 132 if len(train_tasks) < 1: 133 raise ValueError("`train_tasks` must contain at least one task.") 134 if any(t < 1 or t > num_tasks for t in train_tasks): 135 raise ValueError( 136 f"All task IDs in `train_tasks` must be between 1 and {num_tasks}." 137 ) 138 elif isinstance(train_tasks, int): 139 if train_tasks < 0 or train_tasks > num_tasks: 140 raise ValueError( 141 f"`train_tasks` as integer must be between 0 and {num_tasks}." 142 ) 143 else: 144 raise TypeError( 145 "`train_tasks` must be either a list of integers or an integer." 146 ) 147 148 # check eval_tasks 149 eval_tasks = self.cfg.eval_tasks 150 if isinstance(eval_tasks, list): 151 if len(eval_tasks) < 1: 152 raise ValueError("`eval_tasks` must contain at least one task.") 153 if any(t < 1 or t > num_tasks for t in eval_tasks): 154 raise ValueError( 155 f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}." 
156 ) 157 elif isinstance(eval_tasks, int): 158 if eval_tasks < 0 or eval_tasks > num_tasks: 159 raise ValueError( 160 f"`eval_tasks` as integer must be between 0 and {num_tasks}." 161 ) 162 else: 163 raise TypeError( 164 "`eval_tasks` must be either a list of integers or an integer." 165 ) 166 167 def instantiate_mtl_dataset( 168 self, 169 mtl_dataset_cfg: DictConfig, 170 ) -> None: 171 r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`.""" 172 pylogger.debug( 173 "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset)...", 174 mtl_dataset_cfg.get("_target_"), 175 ) 176 self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg) 177 pylogger.debug( 178 "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated!", 179 mtl_dataset_cfg.get("_target_"), 180 ) 181 182 def instantiate_backbone(self, backbone_cfg: DictConfig) -> None: 183 r"""Instantiate the MTL backbone network object from `backbone_cfg`.""" 184 pylogger.debug( 185 "Instantiating backbone network <%s> (clarena.backbones.Backbone)...", 186 backbone_cfg.get("_target_"), 187 ) 188 self.backbone = hydra.utils.instantiate(backbone_cfg) 189 pylogger.debug( 190 "Backbone network <%s> (clarena.backbones.Backbone) instantiated!", 191 backbone_cfg.get("_target_"), 192 ) 193 194 def instantiate_heads( 195 self, 196 input_dim: int, 197 ) -> None: 198 r"""Instantiate the MTL output heads object. 199 200 **Args:** 201 - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone. 202 """ 203 pylogger.debug( 204 "Instantiating MTL heads...", 205 ) 206 self.heads = HeadsMTL(input_dim=input_dim) 207 pylogger.debug("MTL heads instantiated! 
") 208 209 def instantiate_mtl_algorithm( 210 self, 211 mtl_algorithm_cfg: DictConfig, 212 backbone: Backbone, 213 heads: HeadsMTL, 214 non_algorithmic_hparams: dict[str, Any], 215 ) -> None: 216 r"""Instantiate the mtl_algorithm object from `mtl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`.""" 217 pylogger.debug( 218 "MTL algorithm is set as <%s>. Instantiating <%s> (clarena.mtl_algorithms.MTLAlgorithm)...", 219 mtl_algorithm_cfg.get("_target_"), 220 mtl_algorithm_cfg.get("_target_"), 221 ) 222 self.model = hydra.utils.instantiate( 223 mtl_algorithm_cfg, 224 backbone=backbone, 225 heads=heads, 226 non_algorithmic_hparams=non_algorithmic_hparams, 227 ) 228 pylogger.debug( 229 "<%s> (clarena.mtl_algorithms.MTLAlgorithm) instantiated!", 230 mtl_algorithm_cfg.get("_target_"), 231 ) 232 233 def instantiate_optimizer( 234 self, 235 optimizer_cfg: DictConfig, 236 ) -> None: 237 r"""Instantiate the optimizer object from `optimizer_cfg`.""" 238 239 # partially instantiate optimizer as the 'params' argument is from Lightning Modules cannot be passed for now. 240 pylogger.debug( 241 "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...", 242 optimizer_cfg.get("_target_"), 243 ) 244 self.optimizer = hydra.utils.instantiate(optimizer_cfg) 245 pylogger.debug( 246 "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!", 247 optimizer_cfg.get("_target_"), 248 ) 249 250 def instantiate_lr_scheduler( 251 self, 252 lr_scheduler_cfg: DictConfig, 253 ) -> None: 254 r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`.""" 255 256 # partially instantiate learning rate scheduler as the 'params' argument is from Lightning Modules cannot be passed for now. 
257 pylogger.debug( 258 "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...", 259 lr_scheduler_cfg.get("_target_"), 260 ) 261 self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg) 262 pylogger.debug( 263 "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!", 264 lr_scheduler_cfg.get("_target_"), 265 ) 266 267 def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None: 268 r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`.""" 269 pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...") 270 self.lightning_loggers = [ 271 hydra.utils.instantiate(lightning_logger) 272 for lightning_logger in lightning_loggers_cfg.values() 273 ] 274 pylogger.debug("Lightning loggers (lightning.Logger) instantiated!") 275 276 def instantiate_callbacks( 277 self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig 278 ) -> None: 279 r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. 
The instantiated callbacks contain both metric callbacks and other callbacks.""" 280 pylogger.debug("Instantiating callbacks (lightning.Callback)...") 281 282 # instantiate metric callbacks 283 metric_callbacks = [ 284 hydra.utils.instantiate(callback) for callback in metrics_cfg 285 ] 286 287 # instantiate other callbacks 288 other_callbacks = [ 289 hydra.utils.instantiate(callback) for callback in callbacks_cfg 290 ] 291 292 # add metric callbacks to the list of callbacks 293 self.callbacks = metric_callbacks + other_callbacks 294 pylogger.debug("Callbacks (lightning.Callback) instantiated!") 295 296 def instantiate_trainer( 297 self, 298 trainer_cfg: DictConfig, 299 lightning_loggers: list[Logger], 300 callbacks: list[Callback], 301 ) -> None: 302 r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`.""" 303 304 pylogger.debug("Instantiating trainer (lightning.Trainer)...") 305 self.trainer = hydra.utils.instantiate( 306 trainer_cfg, logger=lightning_loggers, callbacks=callbacks 307 ) 308 pylogger.debug("Trainer (lightning.Trainer) instantiated!") 309 310 def set_global_seed(self, global_seed: int) -> None: 311 r"""Set the `global_seed` for the entire experiment.""" 312 L.seed_everything(self.global_seed, workers=True) 313 pylogger.debug("Global seed is set as %d.", global_seed) 314 315 def run(self) -> None: 316 r"""The main method to run the multi-task learning experiment.""" 317 self.set_global_seed(self.global_seed) 318 319 self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset) 320 self.instantiate_backbone(backbone_cfg=self.cfg.backbone) 321 self.instantiate_heads(input_dim=self.cfg.backbone.output_dim) 322 self.instantiate_mtl_algorithm( 323 mtl_algorithm_cfg=self.cfg.mtl_algorithm, 324 backbone=self.backbone, 325 heads=self.heads, 326 non_algorithmic_hparams=select_hyperparameters_from_config( 327 cfg=self.cfg, type=self.cfg.pipeline 328 ), 329 ) # mtl_algorithm should be instantiated after backbone and 
heads 330 self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer) 331 self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler) 332 self.instantiate_lightning_loggers( 333 lightning_loggers_cfg=self.cfg.lightning_loggers 334 ) 335 self.instantiate_callbacks( 336 metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks 337 ) 338 self.instantiate_trainer( 339 trainer_cfg=self.cfg.trainer, 340 lightning_loggers=self.lightning_loggers, 341 callbacks=self.callbacks, 342 ) # trainer should be instantiated after lightning loggers and callbacks 343 344 # setup tasks for dataset and model 345 self.mtl_dataset.setup_tasks_expr( 346 train_tasks=self.train_tasks, eval_tasks=self.eval_tasks 347 ) 348 self.model.setup_tasks( 349 task_ids=self.train_tasks, 350 num_classes={ 351 task_id: len(self.mtl_dataset.get_mtl_class_map(task_id)) 352 for task_id in self.train_tasks 353 }, 354 optimizer=self.optimizer, 355 lr_scheduler=self.lr_scheduler, 356 ) 357 358 # train and validate the model 359 self.trainer.fit( 360 model=self.model, 361 datamodule=self.mtl_dataset, 362 ) 363 364 # evaluation after training and validation 365 self.trainer.test( 366 model=self.model, 367 datamodule=self.mtl_dataset, 368 )
class MTLExperiment:
    r"""The base class for multi-task learning experiment."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""Store the config and declare every experiment component.

        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the multi-task learning experiment.
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        MTLExperiment.sanity_check(self)

        # normalise the task specifications: an integer N stands for tasks 1..N
        if isinstance(cfg.train_tasks, list):
            self.train_tasks: list[int] = cfg.train_tasks
        else:
            self.train_tasks: list[int] = list(range(1, cfg.train_tasks + 1))
        r"""The list of tasks to train."""
        if isinstance(cfg.eval_tasks, list):
            self.eval_tasks: list[int] = cfg.eval_tasks
        else:
            self.eval_tasks: list[int] = list(range(1, cfg.eval_tasks + 1))
        r"""The list of tasks to evaluate."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""

        # components, populated later by the instantiate_* methods
        self.mtl_dataset: MTLDataset
        r"""MTL dataset object."""
        self.backbone: CLBackbone
        r"""Backbone network object."""
        self.heads: HeadsMTL
        r"""MTL output heads object."""
        self.model: MTLAlgorithm
        r"""MTL model object."""
        self.optimizer: Optimizer
        r"""Optimizer object."""
        self.lr_scheduler: LRScheduler | None
        r"""Learning rate scheduler object."""
        self.lightning_loggers: list[Logger]
        r"""The list of initialized lightning loggers objects."""
        self.callbacks: list[Callback]
        r"""The list of initialized callbacks objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Validate the experiment config before any component is instantiated."""

        # every top-level field the experiment index config must provide
        # ("hydra" is excluded as it doesn't appear)
        for field in (
            "pipeline",
            "expr_name",
            "train_tasks",
            "eval_tasks",
            "global_seed",
            "mtl_dataset",
            "mtl_algorithm",
            "backbone",
            "optimizer",
            "lr_scheduler",
            "trainer",
            "metrics",
            "lightning_loggers",
            "callbacks",
            "output_dir",
            "misc",
        ):
            if not self.cfg.get(field):
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # resolve the dataset's number of tasks
        if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL":
            # derived from a CL dataset: the count lives in the nested cl_dataset config
            cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset")
            if cl_dataset_cfg.get("num_tasks"):
                num_tasks = cl_dataset_cfg.get("num_tasks")
            elif cl_dataset_cfg.get("class_split"):
                num_tasks = len(cl_dataset_cfg.class_split)
            elif cl_dataset_cfg.get("datasets"):
                num_tasks = len(cl_dataset_cfg.datasets)
            else:
                raise KeyError(
                    "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
                )
        elif self.cfg.mtl_dataset.get("num_tasks"):
            num_tasks = self.cfg.mtl_dataset.num_tasks
        else:
            raise KeyError(
                "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config."
            )

        # train_tasks and eval_tasks share identical validation rules
        for name in ("train_tasks", "eval_tasks"):
            spec = self.cfg.get(name)
            if isinstance(spec, list):
                if not spec:
                    raise ValueError(f"`{name}` must contain at least one task.")
                if any(task_id < 1 or task_id > num_tasks for task_id in spec):
                    raise ValueError(
                        f"All task IDs in `{name}` must be between 1 and {num_tasks}."
                    )
            elif isinstance(spec, int):
                if spec < 0 or spec > num_tasks:
                    raise ValueError(
                        f"`{name}` as integer must be between 0 and {num_tasks}."
                    )
            else:
                raise TypeError(
                    f"`{name}` must be either a list of integers or an integer."
                )

    def instantiate_mtl_dataset(
        self,
        mtl_dataset_cfg: DictConfig,
    ) -> None:
        r"""Build `self.mtl_dataset` from `mtl_dataset_cfg` via Hydra."""
        target = mtl_dataset_cfg.get("_target_")
        pylogger.debug(
            "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset)...",
            target,
        )
        self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg)
        pylogger.debug(
            "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated!",
            target,
        )

    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
        r"""Build `self.backbone` from `backbone_cfg` via Hydra."""
        target = backbone_cfg.get("_target_")
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
            target,
        )
        self.backbone = hydra.utils.instantiate(backbone_cfg)
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
            target,
        )

    def instantiate_heads(
        self,
        input_dim: int,
    ) -> None:
        r"""Build `self.heads`, the MTL output heads.

        **Args:**
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug("Instantiating MTL heads...")
        self.heads = HeadsMTL(input_dim=input_dim)
        pylogger.debug("MTL heads instantiated!")

    def instantiate_mtl_algorithm(
        self,
        mtl_algorithm_cfg: DictConfig,
        backbone: Backbone,
        heads: HeadsMTL,
        non_algorithmic_hparams: dict[str, Any],
    ) -> None:
        r"""Build `self.model` from `mtl_algorithm_cfg`, wiring in `backbone`, `heads` and `non_algorithmic_hparams`."""
        target = mtl_algorithm_cfg.get("_target_")
        pylogger.debug(
            "MTL algorithm is set as <%s>. Instantiating <%s> (clarena.mtl_algorithms.MTLAlgorithm)...",
            target,
            target,
        )
        self.model = hydra.utils.instantiate(
            mtl_algorithm_cfg,
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )
        pylogger.debug(
            "<%s> (clarena.mtl_algorithms.MTLAlgorithm) instantiated!",
            target,
        )

    def instantiate_optimizer(
        self,
        optimizer_cfg: DictConfig,
    ) -> None:
        r"""Partially build `self.optimizer` from `optimizer_cfg`.

        Instantiation is partial because the `params` argument comes from the
        Lightning Module and cannot be supplied yet.
        """
        target = optimizer_cfg.get("_target_")
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
            target,
        )
        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
            target,
        )

    def instantiate_lr_scheduler(
        self,
        lr_scheduler_cfg: DictConfig,
    ) -> None:
        r"""Partially build `self.lr_scheduler` from `lr_scheduler_cfg`.

        Instantiation is partial because the `params` argument comes from the
        Lightning Module and cannot be supplied yet.
        """
        target = lr_scheduler_cfg.get("_target_")
        pylogger.debug(
            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
            target,
        )
        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
        pylogger.debug(
            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
            target,
        )

    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
        r"""Build `self.lightning_loggers`, one logger per entry of `lightning_loggers_cfg`."""
        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
        loggers = []
        for logger_cfg in lightning_loggers_cfg.values():
            loggers.append(hydra.utils.instantiate(logger_cfg))
        self.lightning_loggers = loggers
        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Build `self.callbacks` from `metrics_cfg` (the metric callbacks) and `callbacks_cfg` (all other callbacks); the result contains both, metric callbacks first."""
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
        instantiated = []
        for callback_cfg in metrics_cfg:  # metric callbacks first
            instantiated.append(hydra.utils.instantiate(callback_cfg))
        for callback_cfg in callbacks_cfg:  # then the remaining callbacks
            instantiated.append(hydra.utils.instantiate(callback_cfg))
        self.callbacks = instantiated
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        lightning_loggers: list[Logger],
        callbacks: list[Callback],
    ) -> None:
        r"""Build `self.trainer` from `trainer_cfg`, attaching `lightning_loggers` and `callbacks`."""
        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(
            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
        )
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire experiment."""
        # NOTE(review): seeds with `self.global_seed` rather than the
        # `global_seed` argument, while the log line reports the argument —
        # confirm this asymmetry is intended.
        L.seed_everything(self.global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""Run the full multi-task learning experiment end to end."""
        cfg = self.cfg
        self.set_global_seed(self.global_seed)

        self.instantiate_mtl_dataset(mtl_dataset_cfg=cfg.mtl_dataset)
        self.instantiate_backbone(backbone_cfg=cfg.backbone)
        self.instantiate_heads(input_dim=cfg.backbone.output_dim)
        # the algorithm wraps the backbone and heads, so it comes after them
        self.instantiate_mtl_algorithm(
            mtl_algorithm_cfg=cfg.mtl_algorithm,
            backbone=self.backbone,
            heads=self.heads,
            non_algorithmic_hparams=select_hyperparameters_from_config(
                cfg=cfg, type=cfg.pipeline
            ),
        )
        self.instantiate_optimizer(optimizer_cfg=cfg.optimizer)
        self.instantiate_lr_scheduler(lr_scheduler_cfg=cfg.lr_scheduler)
        self.instantiate_lightning_loggers(
            lightning_loggers_cfg=cfg.lightning_loggers
        )
        self.instantiate_callbacks(
            metrics_cfg=cfg.metrics, callbacks_cfg=cfg.callbacks
        )
        # the trainer consumes the loggers and callbacks, so it comes last
        self.instantiate_trainer(
            trainer_cfg=cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
        )

        # hand the task layout to the dataset and the model
        self.mtl_dataset.setup_tasks_expr(
            train_tasks=self.train_tasks, eval_tasks=self.eval_tasks
        )
        self.model.setup_tasks(
            task_ids=self.train_tasks,
            num_classes={
                tid: len(self.mtl_dataset.get_mtl_class_map(tid))
                for tid in self.train_tasks
            },
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
        )

        # train and validate, then evaluate
        self.trainer.fit(model=self.model, datamodule=self.mtl_dataset)
        self.trainer.test(model=self.model, datamodule=self.mtl_dataset)
The base class for the multi-task learning experiment.
32 def __init__(self, cfg: DictConfig) -> None: 33 r""" 34 **Args:** 35 - **cfg** (`DictConfig`): the complete config dict for the multi-task learning experiment. 36 """ 37 self.cfg: DictConfig = cfg 38 r"""The complete config dict.""" 39 40 MTLExperiment.sanity_check(self) 41 42 # required config fields 43 self.train_tasks: list[int] = ( 44 cfg.train_tasks 45 if isinstance(cfg.train_tasks, list) 46 else list(range(1, cfg.train_tasks + 1)) 47 ) 48 r"""The list of tasks to train.""" 49 self.eval_tasks: list[int] = ( 50 cfg.eval_tasks 51 if isinstance(cfg.eval_tasks, list) 52 else list(range(1, cfg.eval_tasks + 1)) 53 ) 54 r"""The list of tasks to evaluate.""" 55 self.global_seed: int = cfg.global_seed 56 r"""The global seed for the entire experiment.""" 57 self.output_dir: str = cfg.output_dir 58 r"""The folder for storing the experiment results.""" 59 60 # components 61 self.mtl_dataset: MTLDataset 62 r"""MTL dataset object.""" 63 self.backbone: CLBackbone 64 r"""Backbone network object.""" 65 self.heads: HeadsMTL 66 r"""MTL output heads object.""" 67 self.model: MTLAlgorithm 68 r"""MTL model object.""" 69 self.optimizer: Optimizer 70 r"""Optimizer object.""" 71 self.lr_scheduler: LRScheduler | None 72 r"""Learning rate scheduler object.""" 73 self.lightning_loggers: list[Logger] 74 r"""The list of initialized lightning loggers objects.""" 75 self.callbacks: list[Callback] 76 r"""The list of initialized callbacks objects.""" 77 self.trainer: Trainer 78 r"""Trainer object."""
**Args:**
- **cfg** (`DictConfig`): the complete config dict for the multi-task learning experiment.
The list of initialized lightning loggers objects.
The list of initialized callbacks objects.
80 def sanity_check(self) -> None: 81 r"""Sanity check for config.""" 82 83 # check required config fields 84 required_config_fields = [ 85 "pipeline", 86 "expr_name", 87 "train_tasks", 88 "eval_tasks", 89 "global_seed", 90 "mtl_dataset", 91 "mtl_algorithm", 92 "backbone", 93 "optimizer", 94 "lr_scheduler", 95 "trainer", 96 "metrics", 97 "lightning_loggers", 98 "callbacks", 99 "output_dir", 100 # "hydra" is excluded as it doesn't appear 101 "misc", 102 ] 103 for field in required_config_fields: 104 if not self.cfg.get(field): 105 raise KeyError( 106 f"Field `{field}` is required in the experiment index config." 107 ) 108 109 # get dataset number of tasks 110 if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL": 111 cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset") 112 if cl_dataset_cfg.get("num_tasks"): 113 num_tasks = cl_dataset_cfg.get("num_tasks") 114 elif cl_dataset_cfg.get("class_split"): 115 num_tasks = len(cl_dataset_cfg.class_split) 116 elif cl_dataset_cfg.get("datasets"): 117 num_tasks = len(cl_dataset_cfg.datasets) 118 else: 119 raise KeyError( 120 "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config." 121 ) 122 else: 123 if self.cfg.mtl_dataset.get("num_tasks"): 124 num_tasks = self.cfg.mtl_dataset.num_tasks 125 else: 126 raise KeyError( 127 "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config." 128 ) 129 130 # check train_tasks 131 train_tasks = self.cfg.train_tasks 132 if isinstance(train_tasks, list): 133 if len(train_tasks) < 1: 134 raise ValueError("`train_tasks` must contain at least one task.") 135 if any(t < 1 or t > num_tasks for t in train_tasks): 136 raise ValueError( 137 f"All task IDs in `train_tasks` must be between 1 and {num_tasks}." 
138 ) 139 elif isinstance(train_tasks, int): 140 if train_tasks < 0 or train_tasks > num_tasks: 141 raise ValueError( 142 f"`train_tasks` as integer must be between 0 and {num_tasks}." 143 ) 144 else: 145 raise TypeError( 146 "`train_tasks` must be either a list of integers or an integer." 147 ) 148 149 # check eval_tasks 150 eval_tasks = self.cfg.eval_tasks 151 if isinstance(eval_tasks, list): 152 if len(eval_tasks) < 1: 153 raise ValueError("`eval_tasks` must contain at least one task.") 154 if any(t < 1 or t > num_tasks for t in eval_tasks): 155 raise ValueError( 156 f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}." 157 ) 158 elif isinstance(eval_tasks, int): 159 if eval_tasks < 0 or eval_tasks > num_tasks: 160 raise ValueError( 161 f"`eval_tasks` as integer must be between 0 and {num_tasks}." 162 ) 163 else: 164 raise TypeError( 165 "`eval_tasks` must be either a list of integers or an integer." 166 )
Sanity check for config.
168 def instantiate_mtl_dataset( 169 self, 170 mtl_dataset_cfg: DictConfig, 171 ) -> None: 172 r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`.""" 173 pylogger.debug( 174 "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset)...", 175 mtl_dataset_cfg.get("_target_"), 176 ) 177 self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg) 178 pylogger.debug( 179 "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated!", 180 mtl_dataset_cfg.get("_target_"), 181 )
Instantiate the MTL dataset object from mtl_dataset_cfg.
183 def instantiate_backbone(self, backbone_cfg: DictConfig) -> None: 184 r"""Instantiate the MTL backbone network object from `backbone_cfg`.""" 185 pylogger.debug( 186 "Instantiating backbone network <%s> (clarena.backbones.Backbone)...", 187 backbone_cfg.get("_target_"), 188 ) 189 self.backbone = hydra.utils.instantiate(backbone_cfg) 190 pylogger.debug( 191 "Backbone network <%s> (clarena.backbones.Backbone) instantiated!", 192 backbone_cfg.get("_target_"), 193 )
Instantiate the MTL backbone network object from backbone_cfg.
195 def instantiate_heads( 196 self, 197 input_dim: int, 198 ) -> None: 199 r"""Instantiate the MTL output heads object. 200 201 **Args:** 202 - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone. 203 """ 204 pylogger.debug( 205 "Instantiating MTL heads...", 206 ) 207 self.heads = HeadsMTL(input_dim=input_dim) 208 pylogger.debug("MTL heads instantiated! ")
Instantiate the MTL output heads object.
**Args:**
- **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
210 def instantiate_mtl_algorithm( 211 self, 212 mtl_algorithm_cfg: DictConfig, 213 backbone: Backbone, 214 heads: HeadsMTL, 215 non_algorithmic_hparams: dict[str, Any], 216 ) -> None: 217 r"""Instantiate the mtl_algorithm object from `mtl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`.""" 218 pylogger.debug( 219 "MTL algorithm is set as <%s>. Instantiating <%s> (clarena.mtl_algorithms.MTLAlgorithm)...", 220 mtl_algorithm_cfg.get("_target_"), 221 mtl_algorithm_cfg.get("_target_"), 222 ) 223 self.model = hydra.utils.instantiate( 224 mtl_algorithm_cfg, 225 backbone=backbone, 226 heads=heads, 227 non_algorithmic_hparams=non_algorithmic_hparams, 228 ) 229 pylogger.debug( 230 "<%s> (clarena.mtl_algorithms.MTLAlgorithm) instantiated!", 231 mtl_algorithm_cfg.get("_target_"), 232 )
234 def instantiate_optimizer( 235 self, 236 optimizer_cfg: DictConfig, 237 ) -> None: 238 r"""Instantiate the optimizer object from `optimizer_cfg`.""" 239 240 # partially instantiate optimizer as the 'params' argument is from Lightning Modules cannot be passed for now. 241 pylogger.debug( 242 "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...", 243 optimizer_cfg.get("_target_"), 244 ) 245 self.optimizer = hydra.utils.instantiate(optimizer_cfg) 246 pylogger.debug( 247 "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!", 248 optimizer_cfg.get("_target_"), 249 )
Instantiate the optimizer object from optimizer_cfg.
251 def instantiate_lr_scheduler( 252 self, 253 lr_scheduler_cfg: DictConfig, 254 ) -> None: 255 r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`.""" 256 257 # partially instantiate learning rate scheduler as the 'params' argument is from Lightning Modules cannot be passed for now. 258 pylogger.debug( 259 "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...", 260 lr_scheduler_cfg.get("_target_"), 261 ) 262 self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg) 263 pylogger.debug( 264 "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!", 265 lr_scheduler_cfg.get("_target_"), 266 )
Instantiate the learning rate scheduler object from lr_scheduler_cfg.
268 def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None: 269 r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`.""" 270 pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...") 271 self.lightning_loggers = [ 272 hydra.utils.instantiate(lightning_logger) 273 for lightning_logger in lightning_loggers_cfg.values() 274 ] 275 pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")
Instantiate the list of lightning loggers objects from lightning_loggers_cfg.
277 def instantiate_callbacks( 278 self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig 279 ) -> None: 280 r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.""" 281 pylogger.debug("Instantiating callbacks (lightning.Callback)...") 282 283 # instantiate metric callbacks 284 metric_callbacks = [ 285 hydra.utils.instantiate(callback) for callback in metrics_cfg 286 ] 287 288 # instantiate other callbacks 289 other_callbacks = [ 290 hydra.utils.instantiate(callback) for callback in callbacks_cfg 291 ] 292 293 # add metric callbacks to the list of callbacks 294 self.callbacks = metric_callbacks + other_callbacks 295 pylogger.debug("Callbacks (lightning.Callback) instantiated!")
Instantiate the list of callbacks objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.
297 def instantiate_trainer( 298 self, 299 trainer_cfg: DictConfig, 300 lightning_loggers: list[Logger], 301 callbacks: list[Callback], 302 ) -> None: 303 r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`.""" 304 305 pylogger.debug("Instantiating trainer (lightning.Trainer)...") 306 self.trainer = hydra.utils.instantiate( 307 trainer_cfg, logger=lightning_loggers, callbacks=callbacks 308 ) 309 pylogger.debug("Trainer (lightning.Trainer) instantiated!")
Instantiate the trainer object from trainer_cfg, lightning_loggers, and callbacks.
311 def set_global_seed(self, global_seed: int) -> None: 312 r"""Set the `global_seed` for the entire experiment.""" 313 L.seed_everything(self.global_seed, workers=True) 314 pylogger.debug("Global seed is set as %d.", global_seed)
Set the global_seed for the entire experiment.
    def run(self) -> None:
        r"""The main method to run the multi-task learning experiment."""
        # Seed everything first so all subsequent instantiation and training
        # is reproducible.
        self.set_global_seed(self.global_seed)

        # Phase 1: instantiate all components from the config. Order matters
        # where noted below.
        self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset)
        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
        # heads take the backbone's output dimension as their input dimension
        self.instantiate_heads(input_dim=self.cfg.backbone.output_dim)
        self.instantiate_mtl_algorithm(
            mtl_algorithm_cfg=self.cfg.mtl_algorithm,
            backbone=self.backbone,
            heads=self.heads,
            non_algorithmic_hparams=select_hyperparameters_from_config(
                cfg=self.cfg, type=self.cfg.pipeline
            ),
        )  # mtl_algorithm should be instantiated after backbone and heads
        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
        self.instantiate_lightning_loggers(
            lightning_loggers_cfg=self.cfg.lightning_loggers
        )
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after lightning loggers and callbacks

        # Phase 2: setup tasks for dataset and model. The number of classes
        # per task is read from the dataset's class map for each train task.
        self.mtl_dataset.setup_tasks_expr(
            train_tasks=self.train_tasks, eval_tasks=self.eval_tasks
        )
        self.model.setup_tasks(
            task_ids=self.train_tasks,
            num_classes={
                task_id: len(self.mtl_dataset.get_mtl_class_map(task_id))
                for task_id in self.train_tasks
            },
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
        )

        # Phase 3: train and validate the model
        self.trainer.fit(
            model=self.model,
            datamodule=self.mtl_dataset,
        )

        # Phase 4: evaluation after training and validation
        self.trainer.test(
            model=self.model,
            datamodule=self.mtl_dataset,
        )
The main method to run the multi-task learning experiment.