clarena.pipelines.cl_main_expr
The submodule in `pipelines` for the continual learning main experiment.
```python
r"""
The submodule in `pipelines` for the continual learning main experiment.
"""

__all__ = ["CLMainExperiment"]

import logging
from typing import Any

import hydra
import lightning as L
from lightning import Callback, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, ListConfig
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from clarena.backbones import CLBackbone
from clarena.cl_algorithms import CLAlgorithm
from clarena.cl_datasets import CLDataset
from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL
from clarena.utils.cfg import select_hyperparameters_from_config

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CLMainExperiment:
    r"""The base class for the continual learning main experiment."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the continual learning main experiment.
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        CLMainExperiment.sanity_check(self)

        # required config fields
        self.cl_paradigm: str = cfg.cl_paradigm
        r"""The continual learning paradigm."""
        self.train_tasks: list[int] = (
            cfg.train_tasks
            if isinstance(cfg.train_tasks, ListConfig)
            else list(range(1, cfg.train_tasks + 1))
        )
        r"""The list of task IDs to train."""
        self.eval_after_tasks: list[int] = (
            cfg.eval_after_tasks
            if isinstance(cfg.eval_after_tasks, ListConfig)
            else list(range(1, cfg.eval_after_tasks + 1))
        )
        r"""If task ID $t$ is in this list, run the evaluation process for all seen tasks after training task $t$."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""

        # components

        # global components
        self.cl_dataset: CLDataset
        r"""CL dataset object."""
        self.backbone: CLBackbone
        r"""Backbone network object."""
        self.heads: HeadsTIL | HeadsCIL | HeadDIL
        r"""CL output heads object."""
        self.model: CLAlgorithm
        r"""CL model object."""
        self.lightning_loggers: list[Logger]
        r"""Lightning logger objects."""
        self.callbacks: list[Callback]
        r"""Callback objects."""

        # task-specific components
        self.optimizer_t: Optimizer
        r"""Optimizer object for the current task `self.task_id`."""
        self.lr_scheduler_t: LRScheduler | None = None
        r"""Learning rate scheduler object for the current task `self.task_id`."""
        self.trainer_t: Trainer
        r"""Trainer object for the current task `self.task_id`."""

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed, updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

    def sanity_check(self) -> None:
        r"""Sanity check for config."""

        # check required config fields
        required_config_fields = [
            "pipeline",
            "expr_name",
            "cl_paradigm",
            "train_tasks",
            "eval_after_tasks",
            "global_seed",
            "cl_dataset",
            "cl_algorithm",
            "backbone",
            "optimizer",
            "trainer",
            "metrics",
            "lightning_loggers",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            if not self.cfg.get(field):
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # check cl_paradigm
        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
            raise ValueError(
                f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
            )

        # get dataset number of tasks
        if self.cfg.cl_dataset.get("num_tasks"):
            num_tasks = self.cfg.cl_dataset.get("num_tasks")
        elif self.cfg.cl_dataset.get("class_split"):
            num_tasks = len(self.cfg.cl_dataset.class_split)
        elif self.cfg.cl_dataset.get("datasets"):
            num_tasks = len(self.cfg.cl_dataset.datasets)
        else:
            raise KeyError(
                "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
            )

        # check train_tasks
        train_tasks = self.cfg.train_tasks
        if isinstance(train_tasks, ListConfig):
            if len(train_tasks) < 1:
                raise ValueError("`train_tasks` config must contain at least one task.")
            if any(t < 1 or t > num_tasks for t in train_tasks):
                raise ValueError(
                    f"All task IDs in `train_tasks` config must be between 1 and {num_tasks}."
                )
        elif isinstance(train_tasks, int):
            if train_tasks < 0 or train_tasks > num_tasks:
                raise ValueError(
                    f"`train_tasks` config as integer must be between 0 and {num_tasks}."
                )
        else:
            raise TypeError(
                "`train_tasks` config must be either a list of integers or an integer."
            )

        # check eval_after_tasks
        eval_after_tasks = self.cfg.eval_after_tasks
        if isinstance(eval_after_tasks, ListConfig):
            if len(eval_after_tasks) < 1:
                raise ValueError(
                    "`eval_after_tasks` config must contain at least one task."
                )
            if any(t < 1 or t > num_tasks for t in eval_after_tasks):
                raise ValueError(
                    f"All task IDs in `eval_after_tasks` config must be between 1 and {num_tasks}."
                )
        elif isinstance(eval_after_tasks, int):
            if eval_after_tasks < 0 or eval_after_tasks > num_tasks:
                raise ValueError(
                    f"`eval_after_tasks` config as integer must be between 0 and {num_tasks}."
                )
        else:
            raise TypeError(
                "`eval_after_tasks` config must be either a list of integers or an integer."
            )

        # check that eval_after_tasks is a subset of train_tasks
        if isinstance(train_tasks, ListConfig) and isinstance(
            eval_after_tasks, ListConfig
        ):
            if not set(eval_after_tasks).issubset(set(train_tasks)):
                raise ValueError(
                    "`eval_after_tasks` config must be a subset of `train_tasks` config."
                )

    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            cl_dataset_cfg.get("_target_"),
        )
        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            cl_dataset_cfg.get("_target_"),
        )

    def instantiate_backbone(
        self, backbone_cfg: DictConfig, disable_unlearning: bool
    ) -> None:
        r"""Instantiate the CL backbone network object from `backbone_cfg`."""
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.CLBackbone)...",
            backbone_cfg.get("_target_"),
        )
        self.backbone = hydra.utils.instantiate(
            backbone_cfg, disable_unlearning=disable_unlearning
        )
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.CLBackbone) instantiated!",
            backbone_cfg.get("_target_"),
        )

    def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
        r"""Instantiate the CL output heads object.

        **Args:**
        - **cl_paradigm** (`str`): the CL paradigm, either 'TIL', 'CIL' or 'DIL'. 'TIL' uses `HeadsTIL`, 'CIL' uses `HeadsCIL`, and 'DIL' uses `HeadDIL`.
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug(
            "CL paradigm is set as %s. Instantiating %s heads...",
            cl_paradigm,
            cl_paradigm,
        )
        if cl_paradigm == "TIL":
            self.heads = HeadsTIL(input_dim=input_dim)
        elif cl_paradigm == "CIL":
            self.heads = HeadsCIL(input_dim=input_dim)
        elif cl_paradigm == "DIL":
            self.heads = HeadDIL(input_dim=input_dim)

        pylogger.debug("%s heads instantiated!", cl_paradigm)

    def instantiate_cl_algorithm(
        self,
        cl_algorithm_cfg: DictConfig,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any],
        disable_unlearning: bool,
    ) -> None:
        r"""Instantiate the CL algorithm object from `cl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
        pylogger.debug(
            "CL algorithm is set as <%s>. Instantiating <%s> (clarena.cl_algorithms.CLAlgorithm)...",
            cl_algorithm_cfg.get("_target_"),
            cl_algorithm_cfg.get("_target_"),
        )
        self.model = hydra.utils.instantiate(
            cl_algorithm_cfg,
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            disable_unlearning=disable_unlearning,
        )
        pylogger.debug(
            "<%s> (clarena.cl_algorithms.CLAlgorithm) instantiated!",
            cl_algorithm_cfg.get("_target_"),
        )

    def instantiate_optimizer(
        self,
        optimizer_cfg: DictConfig,
        task_id: int,
    ) -> None:
        r"""Instantiate the optimizer object for task `task_id` from `optimizer_cfg`."""

        # distinguish whether the optimizer config is uniform or task-specific
        if not optimizer_cfg.get("_target_"):
            pylogger.debug("Distinct optimizer config is applied to each task.")
            optimizer_cfg = optimizer_cfg[task_id]
        else:
            pylogger.debug("Uniform optimizer config is applied to all tasks.")

        # partially instantiate the optimizer, as the 'params' argument from the
        # Lightning module cannot be passed for now
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer) for task %d...",
            optimizer_cfg.get("_target_"),
            task_id,
        )
        self.optimizer_t = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated for task %d!",
            optimizer_cfg.get("_target_"),
            task_id,
        )

    def instantiate_lr_scheduler(
        self,
        lr_scheduler_cfg: DictConfig,
        task_id: int,
    ) -> None:
        r"""Instantiate the learning rate scheduler object for task `task_id` from `lr_scheduler_cfg`."""

        # distinguish whether the learning rate scheduler config is uniform or task-specific
        if not lr_scheduler_cfg.get("_target_"):
            pylogger.debug(
                "Distinct learning rate scheduler config is applied to each task."
            )
            lr_scheduler_cfg = lr_scheduler_cfg[task_id]
        else:
            pylogger.debug(
                "Uniform learning rate scheduler config is applied to all tasks."
            )

        # partially instantiate the learning rate scheduler, as the 'optimizer'
        # argument from the Lightning module cannot be passed for now
        pylogger.debug(
            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) for task %d...",
            lr_scheduler_cfg.get("_target_"),
            task_id,
        )
        self.lr_scheduler_t = hydra.utils.instantiate(lr_scheduler_cfg)
        pylogger.debug(
            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated for task %d!",
            lr_scheduler_cfg.get("_target_"),
            task_id,
        )

    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
        r"""Instantiate the list of Lightning logger objects from `lightning_loggers_cfg`."""

        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
        self.lightning_loggers = [
            hydra.utils.instantiate(lightning_logger)
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")

    def instantiate_callbacks(
        self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
    ) -> None:
        r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
        pylogger.debug(
            "Instantiating callbacks (lightning.Callback)...",
        )

        # instantiate metric callbacks
        metric_callbacks = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # add metric callbacks to the list of callbacks
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        lightning_loggers: list[Logger],
        callbacks: list[Callback],
        task_id: int,
    ) -> None:
        r"""Instantiate the trainer object for task `task_id` from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""

        # distinguish whether the trainer config is uniform or task-specific
        if not trainer_cfg.get("_target_"):
            pylogger.debug("Distinct trainer config is applied to each task.")
            trainer_cfg = trainer_cfg[task_id]
        else:
            pylogger.debug("Uniform trainer config is applied to all tasks.")

        pylogger.debug(
            "Instantiating trainer (lightning.Trainer) for task %d...",
            task_id,
        )
        self.trainer_t = hydra.utils.instantiate(
            trainer_cfg,
            logger=lightning_loggers,
            callbacks=callbacks,
        )
        pylogger.debug(
            "Trainer (lightning.Trainer) for task %d instantiated!",
            task_id,
        )

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire experiment."""
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the continual learning main experiment."""

        self.set_global_seed(self.global_seed)

        # global components
        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
        self.instantiate_backbone(
            backbone_cfg=self.cfg.backbone, disable_unlearning=True
        )
        self.instantiate_heads(
            cl_paradigm=self.cl_paradigm, input_dim=self.cfg.backbone.output_dim
        )
        self.instantiate_cl_algorithm(
            cl_algorithm_cfg=self.cfg.cl_algorithm,
            backbone=self.backbone,
            heads=self.heads,
            non_algorithmic_hparams=select_hyperparameters_from_config(
                cfg=self.cfg, type=self.cfg.pipeline
            ),
            disable_unlearning=True,
        )  # cl_algorithm should be instantiated after backbone and heads
        self.instantiate_lightning_loggers(
            lightning_loggers_cfg=self.cfg.lightning_loggers
        )
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics,
            callbacks_cfg=self.cfg.callbacks,
        )

        # task loop
        for task_id in self.train_tasks:

            self.task_id = task_id

            # task-specific components
            self.instantiate_optimizer(
                optimizer_cfg=self.cfg.optimizer,
                task_id=task_id,
            )
            if self.cfg.get("lr_scheduler"):
                self.instantiate_lr_scheduler(
                    lr_scheduler_cfg=self.cfg.lr_scheduler,
                    task_id=task_id,
                )
            self.instantiate_trainer(
                trainer_cfg=self.cfg.trainer,
                lightning_loggers=self.lightning_loggers,
                callbacks=self.callbacks,
                task_id=task_id,
            )  # trainer should be instantiated after lightning loggers and callbacks

            # set up task ID for dataset and model
            self.cl_dataset.setup_task_id(task_id=task_id)
            self.model.setup_task_id(
                task_id=task_id,
                num_classes=len(self.cl_dataset.get_cl_class_map(self.task_id)),
                optimizer=self.optimizer_t,
                lr_scheduler=self.lr_scheduler_t,
            )

            # train and validate the model
            self.trainer_t.fit(
                model=self.model,
                datamodule=self.cl_dataset,
            )

            # evaluation after training and validation
            if task_id in self.eval_after_tasks:
                self.trainer_t.test(
                    model=self.model,
                    datamodule=self.cl_dataset,
                )

            self.processed_task_ids.append(task_id)
```
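A minimal sketch of how this pipeline might be launched from a script. The `@hydra.main` arguments (`config_path="configs"`, `config_name="experiment"`) are illustrative assumptions, not paths shipped with this module:

```python
import hydra
from omegaconf import DictConfig

from clarena.pipelines.cl_main_expr import CLMainExperiment


@hydra.main(version_base=None, config_path="configs", config_name="experiment")
def main(cfg: DictConfig) -> None:
    experiment = CLMainExperiment(cfg)  # sanity-checks the config on construction
    experiment.run()  # fit each task in `train_tasks`, test after tasks in `eval_after_tasks`


if __name__ == "__main__":
    main()
```

Construction runs `sanity_check` before any component is built, so config errors surface before training starts.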
`class CLMainExperiment`

The base class for the continual learning main experiment.
`CLMainExperiment(cfg: DictConfig)`
**Args:**
- **cfg** (`DictConfig`): the complete config dict for the continual learning main experiment.
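For a sense of the shape `cfg` must take, here is a skeletal config built with `OmegaConf`. All values are hypothetical placeholders (in particular the `"..."` targets, which must point at concrete classes); it is only meant to show the fields that `sanity_check` requires to be present and non-empty:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "pipeline": "cl_main",  # placeholder pipeline name
        "expr_name": "til_demo",
        "cl_paradigm": "TIL",  # one of 'TIL', 'CIL', 'DIL'
        "train_tasks": 2,  # an int is expanded to [1, 2]
        "eval_after_tasks": [2],  # test all seen tasks after task 2
        "global_seed": 42,
        "cl_dataset": {"_target_": "...", "num_tasks": 2},
        "cl_algorithm": {"_target_": "..."},
        "backbone": {"_target_": "...", "output_dim": 64},
        "optimizer": {"_target_": "torch.optim.Adam", "_partial_": True, "lr": 1e-3},
        "trainer": {"_target_": "lightning.Trainer", "max_epochs": 1},
        "metrics": ["..."],
        "lightning_loggers": {
            "csv": {
                "_target_": "lightning.pytorch.loggers.CSVLogger",
                "save_dir": "logs",
            }
        },
        "callbacks": ["..."],
        "output_dir": "outputs/til_demo",
        "misc": {"note": "placeholder"},
    }
)
```

This skeleton passes `sanity_check` as written, but actually running the experiment requires real component classes behind each `_target_`.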
Instance attributes:

- `eval_after_tasks`: If task ID $t$ is in this list, run the evaluation process for all seen tasks after training task $t$.
- `lr_scheduler_t`: Learning rate scheduler object for the current task `self.task_id`.
- `trainer_t`: Trainer object for the current task `self.task_id`.
- `task_id`: Task ID counter indicating which task is being processed, updated during the task loop. Valid from 1 to the number of tasks in the CL dataset.
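`train_tasks` and `eval_after_tasks` accept either an explicit list of task IDs or a single integer that `__init__` expands into the range starting at task 1. A small sketch of that normalization (the helper name `normalize_tasks` is illustrative):

```python
from omegaconf import ListConfig, OmegaConf


def normalize_tasks(value: int | ListConfig) -> list[int]:
    # mirrors the expression used for `train_tasks` / `eval_after_tasks` in `__init__`
    return value if isinstance(value, ListConfig) else list(range(1, value + 1))


print(normalize_tasks(3))  # [1, 2, 3]
print(normalize_tasks(OmegaConf.create([1, 3])))  # [1, 3] (lists pass through unchanged)
```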
`sanity_check() -> None`
Sanity check for config.
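One detail worth noting: the presence check is `if not self.cfg.get(field)`, so a field that exists but is falsy (say, an empty `metrics` list) fails exactly like a missing field. A tiny demonstration:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"metrics": []})  # field present, but empty
if not cfg.get("metrics"):
    # this branch is taken: falsy values count as missing
    print("Field `metrics` is required in the experiment index config.")
```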
`instantiate_cl_dataset(cl_dataset_cfg: DictConfig) -> None`
Instantiate the CL dataset object from `cl_dataset_cfg`.
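All component factories in this class share the same Hydra pattern: the config carries a `_target_` key naming a class, and the remaining keys become constructor arguments. A generic, runnable sketch (using `collections.Counter` as a stand-in target rather than a real clarena dataset):

```python
import hydra
from omegaconf import OmegaConf

# stand-in target so the sketch runs anywhere; a real config would point
# at a concrete dataset class from clarena.cl_datasets
cl_dataset_cfg = OmegaConf.create(
    {"_target_": "collections.Counter", "red": 2, "blue": 1}
)
obj = hydra.utils.instantiate(cl_dataset_cfg)  # same as Counter(red=2, blue=1)
print(obj)  # Counter({'red': 2, 'blue': 1})
```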
`instantiate_backbone(backbone_cfg: DictConfig, disable_unlearning: bool) -> None`
Instantiate the CL backbone network object from `backbone_cfg`.
`instantiate_heads(cl_paradigm: str, input_dim: int) -> None`
Instantiate the CL output heads object.
**Args:**
- **cl_paradigm** (`str`): the CL paradigm, either 'TIL', 'CIL' or 'DIL'. 'TIL' uses `HeadsTIL`, 'CIL' uses `HeadsCIL`, and 'DIL' uses `HeadDIL`.
- **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
`instantiate_cl_algorithm(cl_algorithm_cfg: DictConfig, backbone: CLBackbone, heads: HeadsTIL | HeadsCIL | HeadDIL, non_algorithmic_hparams: dict[str, Any], disable_unlearning: bool) -> None`

Instantiate the CL algorithm object from `cl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`.
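`hydra.utils.instantiate` also accepts keyword overrides alongside the config, which is how the already-built `backbone` and `heads` objects are injected into the algorithm's constructor here. A minimal sketch of that pass-through mechanism, with `builtins.dict` as a stand-in target:

```python
import hydra
from omegaconf import OmegaConf

# `builtins.dict` as a stand-in target: dict(lr=0.01, backbone=...) just
# collects whatever arguments it is given
cfg = OmegaConf.create({"_target_": "builtins.dict", "lr": 0.01})
backbone = object()  # stand-in for an already-instantiated CLBackbone
model = hydra.utils.instantiate(cfg, backbone=backbone)
print(model["lr"], model["backbone"] is backbone)  # 0.01 True
```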
`instantiate_optimizer(optimizer_cfg: DictConfig, task_id: int) -> None`
Instantiate the optimizer object for task `task_id` from `optimizer_cfg`.
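Two config shapes are supported: a uniform config with a top-level `_target_` applied to every task, or a task-keyed mapping without one, from which `optimizer_cfg[task_id]` selects the current task's entry. Since the model parameters are not available yet, the config should yield a partial; with Hydra that is what `_partial_: true` does (an assumption consistent with the "partially instantiate" comment in the source):

```python
import hydra
import torch
from omegaconf import OmegaConf

# uniform form: a top-level `_target_` applied to every task; `_partial_: true`
# makes Hydra return functools.partial(SGD, lr=0.1) instead of calling SGD now
uniform_cfg = OmegaConf.create(
    {"_target_": "torch.optim.SGD", "_partial_": True, "lr": 0.1}
)
opt_factory = hydra.utils.instantiate(uniform_cfg)

# the Lightning module can bind the model parameters later:
params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = opt_factory(params)
print(type(optimizer).__name__)  # SGD

# task-specific form (no top-level `_target_`): `instantiate_optimizer` would
# select `optimizer_cfg[task_id]` from a mapping such as
# {1: {"_target_": "torch.optim.SGD", "_partial_": True, "lr": 0.1},
#  2: {"_target_": "torch.optim.Adam", "_partial_": True, "lr": 1e-3}}
```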
`instantiate_lr_scheduler(lr_scheduler_cfg: DictConfig, task_id: int) -> None`
Instantiate the learning rate scheduler object for task `task_id` from `lr_scheduler_cfg`.
`instantiate_lightning_loggers(lightning_loggers_cfg: DictConfig) -> None`
Instantiate the list of Lightning logger objects from `lightning_loggers_cfg`.
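`lightning_loggers_cfg` is expected to be a mapping whose values are individual logger configs (the method iterates over `lightning_loggers_cfg.values()`, so the keys are just names). A runnable sketch with one real Lightning logger; the `save_dir` value is an arbitrary example:

```python
import hydra
from omegaconf import OmegaConf

lightning_loggers_cfg = OmegaConf.create(
    {
        "csv": {
            "_target_": "lightning.pytorch.loggers.CSVLogger",
            "save_dir": "logs",  # arbitrary example location
        },
    }
)
loggers = [hydra.utils.instantiate(c) for c in lightning_loggers_cfg.values()]
print([type(logger).__name__ for logger in loggers])  # ['CSVLogger']
```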
`instantiate_callbacks(metrics_cfg: ListConfig, callbacks_cfg: ListConfig) -> None`
Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.
`instantiate_trainer(trainer_cfg: DictConfig, lightning_loggers: list[Logger], callbacks: list[Callback], task_id: int) -> None`
Instantiate the trainer object for task `task_id` from `trainer_cfg`, `lightning_loggers`, and `callbacks`.
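The trainer follows the same uniform-vs-task-specific dispatch as the optimizer. In the task-specific form the code indexes `trainer_cfg[task_id]`, which suggests a mapping keyed by task ID; the layout below is an assumption consistent with that indexing, not a documented schema:

```python
from omegaconf import OmegaConf

# hypothetical task-keyed layout; integer keys are supported by OmegaConf
trainer_cfg = OmegaConf.create(
    {
        1: {"_target_": "lightning.Trainer", "max_epochs": 5},
        2: {"_target_": "lightning.Trainer", "max_epochs": 10},
    }
)
task_id = 2
# mirrors the dispatch in `instantiate_trainer`
selected = trainer_cfg if trainer_cfg.get("_target_") else trainer_cfg[task_id]
print(selected.max_epochs)  # 10
```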
`set_global_seed(global_seed: int) -> None`
Set the `global_seed` for the entire experiment.
`run() -> None`
The main method to run the continual learning main experiment.
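To make the control flow concrete: with `train_tasks = [1, 2, 3]` and `eval_after_tasks = [3]`, the loop fits three times but tests only once, after the final task. A stripped-down trace of the loop logic in `run`:

```python
train_tasks = [1, 2, 3]
eval_after_tasks = [3]
processed_task_ids = []

for task_id in train_tasks:
    print(f"fit task {task_id}")  # trainer_t.fit(model, datamodule)
    if task_id in eval_after_tasks:
        print(f"test after task {task_id}")  # trainer_t.test over all seen tasks
    processed_task_ids.append(task_id)

# output: fit task 1, fit task 2, fit task 3, test after task 3
```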