clarena
Welcome to CLArena
CLArena (Continual Learning Arena) is an open-source Python package for Continual Learning (CL) research. In this package, we provide an integrated environment and various APIs to conduct CL experiments for research purposes, as well as implemented CL algorithms and datasets that you can give a spin immediately.
Please note that this is an API documentation providing detailed information about the available classes, functions, and modules in CLArena. Please refer to the main documentation and my beginners' guide to continual learning for more intuitive tutorials, examples, and guides on how to use CLArena:

- [Main Documentation](https://pengxiang-wang.com/projects/continual-learning-arena)
- [A Beginners' Guide to Continual Learning](https://pengxiang-wang.com/posts/continual-learning-beginners-guide)
We provide various components of a continual learning system in the submodules:
`clarena.cl_datasets`
: Continual learning datasets.

`clarena.backbones`
: Neural network architectures used as backbones for CL algorithms.

`clarena.cl_heads`
: Multi-head classifiers for continual learning outputs. Task-Incremental Learning (TIL) head and Class-Incremental Learning (CIL) head are included.

`clarena.cl_algorithms`
: Implementation of various continual learning algorithms.

`clarena.callbacks`
: Extra actions added in the continual learning process.

`utils`
: Utility functions for continual learning experiments.
As well as the base class in the outermost directory of the package:
`CLExperiment`
: The base class for continual learning experiments.
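For orientation, the pieces above are typically driven from a Hydra entry point. A minimal sketch, assuming a hypothetical `configs/entrance.yaml` config layout (the directory and config name are placeholders, not CLArena's canonical entry point):

```python
import hydra
from omegaconf import DictConfig

from clarena import CLExperiment


@hydra.main(version_base=None, config_path="configs", config_name="entrance")  # hypothetical config layout
def main(cfg: DictConfig) -> None:
    expr = CLExperiment(cfg)  # runs the config sanity check on construction
    expr.run()  # the task loop: instantiate, set up, then fit/test each task


if __name__ == "__main__":
    main()
```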
1r""" 2 3# Welcome to CLArena 4 5**CLArena (Continual Learning Arena)** is a open-source Python package for Continual Learning (CL) research. In this package, we provide a integrated environment and various APIs to conduct CL experiments for research purposes, as well as implemented CL algorithms and datasets that you can give it a spin immediately. 6 7Please note that this is an API documantation providing detailed information about the available classes, functions, and modules in CLArena. Please refer to the main documentation and my beginners' guide to continual learning for more intuitive tutorials, examples, and guides on how to use CLArena: 8 9- [**Main Documentation**](https://pengxiang-wang.com/projects/continual-learning-arena) 10- [**A Beginners' Guide to Continual Learning**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide) 11 12We provide various components of continual learning system in the submodules: 13 14- `clarena.cl_datasets`: Continual learning datasets. 15- `clarena.backbones`: Neural network architectures used as backbones for CL algorithms. 16- `clarena.cl_heads`: Multi-head classifiers for continual learning outputs. Task-Incremental Learning (TIL) head and Class-Incremental Learning (CIL) head are included. 17- `clarena.cl_algorithms`: Implementation of various continual learning algorithms. 18- `clarena.callbacks`: Extra actions added in the continual learning process. 19- `utils`: Utility functions for continual learning experiments. 20 21As well as the base class in the outmost directory of the package: 22 23- `CLExperiment`: The base class for continual learning experiments. 24 25""" 26 27from .base import CLExperiment 28 29__all__ = [ 30 "CLExperiment", 31 "cl_datasets", 32 "backbones", 33 "cl_heads", 34 "cl_algorithms", 35 "callbacks", 36 "utils", 37]
```python
class CLExperiment:
    r"""The base class for continual learning experiments."""
```
The base class for continual learning experiments.
```python
    def __init__(self, cfg: DictConfig) -> None:
        r"""Initializes the CL experiment object with a complete configuration.

        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the CL experiment.
        """
        self.cfg: DictConfig = cfg
        r"""Store the complete config dict for any future reference."""

        self.cl_paradigm: str = cfg.cl_paradigm
        r"""Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Parsed from config and used to instantiate the correct heads object and set up the CL dataset."""
        self.num_tasks: int = cfg.num_tasks
        r"""Store the number of tasks to be conducted in this experiment. Parsed from config and used in the task loop."""
        self.global_seed: int | None = cfg.get("global_seed")  # .get() returns None when absent and keeps seed 0 valid
        r"""Store the global seed for the entire experiment. Parsed from config and used to seed all random number generators."""
        self.test: bool = cfg.test
        r"""Store whether to test the model after training and validation. Parsed from config and used in the task loop."""
        self.output_dir_name: str = cfg.output_dir_name
        r"""Store the name of the output directory for logs and checkpoints. Parsed from config and helps any output operation locate the correct directory."""

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Updated automatically during the task loop."""

        self.cl_dataset: CLDataset
        r"""CL dataset object. Instantiated in `instantiate_cl_dataset()`."""
        self.backbone: CLBackbone
        r"""Backbone network object. Instantiated in `instantiate_backbone()`."""
        self.heads: HeadsTIL | HeadsCIL
        r"""CL output heads object. Instantiated in `instantiate_heads()`."""
        self.model: CLAlgorithm
        r"""CL model object. Instantiated in `instantiate_cl_algorithm()`."""

        self.optimizer: Optimizer
        r"""Optimizer object for the current task `self.task_id`. Instantiated in `instantiate_optimizer()`."""
        self.trainer: Trainer
        r"""Trainer object for the current task `self.task_id`. Instantiated in `instantiate_trainer()`."""
        self.lightning_loggers: list[Logger]
        r"""The list of initialised Lightning logger objects for the current task `self.task_id`. Instantiated in `instantiate_lightning_loggers()`."""
        self.callbacks: list[Callback]
        r"""The list of initialised callback objects for the current task `self.task_id`. Instantiated in `instantiate_callbacks()`."""

        CLExperiment.sanity_check(self)  # call the base class check explicitly so subclass overrides don't bypass it
```
Initializes the CL experiment object with a complete configuration.
**Args:**
- **cfg** (`DictConfig`): the complete config dict for the CL experiment.
**Attributes:**
- `cfg`: Store the complete config dict for any future reference.
- `cl_paradigm`: Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Parsed from config and used to instantiate the correct heads object and set up the CL dataset.
- `num_tasks`: Store the number of tasks to be conducted in this experiment. Parsed from config and used in the task loop.
- `global_seed`: Store the global seed for the entire experiment. Parsed from config and used to seed all random number generators.
- `test`: Store whether to test the model after training and validation. Parsed from config and used in the task loop.
- `output_dir_name`: Store the name of the output directory for logs and checkpoints. Parsed from config and helps any output operation locate the correct directory.
- `task_id`: Task ID counter indicating which task is being processed. Updated automatically during the task loop.
- `cl_dataset`: CL dataset object. Instantiated in `instantiate_cl_dataset()`.
- `backbone`: Backbone network object. Instantiated in `instantiate_backbone()`.
- `heads`: CL output heads object. Instantiated in `instantiate_heads()`.
- `model`: CL model object. Instantiated in `instantiate_cl_algorithm()`.
- `optimizer`: Optimizer object for the current task `self.task_id`. Instantiated in `instantiate_optimizer()`.
- `trainer`: Trainer object for the current task `self.task_id`. Instantiated in `instantiate_trainer()`.
- `lightning_loggers`: The list of initialised Lightning logger objects for the current task `self.task_id`. Instantiated in `instantiate_lightning_loggers()`.
- `callbacks`: The list of initialised callback objects for the current task `self.task_id`. Instantiated in `instantiate_callbacks()`.
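For illustration, a config sketch covering the fields read in `__init__()` and checked in `sanity_check()`; the `_target_` placeholders stand in for real component configs and are not actual CLArena class paths:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "cl_paradigm": "TIL",  # or "CIL"
        "num_tasks": 5,
        "test": True,  # test after training and validation
        "output_dir_name": "outputs/til_demo",
        "global_seed": 42,  # optional
        "cl_dataset": {"_target_": "...", "num_tasks": 5},  # must cover the experiment's num_tasks
        # backbone, cl_algorithm, optimizer, trainer, lightning_loggers and
        # callbacks configs follow the same Hydra `_target_` convention.
    }
)
```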
```python
    def sanity_check(self) -> None:
        r"""Check the sanity of the config dict `self.cfg`.

        **Raises:**
        - **KeyError**: when required fields in experiment config are missing, including `cl_paradigm`, `num_tasks`, `test`, `output_dir_name`.
        - **ValueError**: when the value of `cl_paradigm` is not 'TIL' or 'CIL', or when the number of tasks is larger than the number of tasks in the CL dataset.
        """
        if not self.cfg.get("cl_paradigm"):
            raise KeyError(
                "Field cl_paradigm should be specified in experiment config!"
            )

        if self.cfg.cl_paradigm not in ["TIL", "CIL"]:
            raise ValueError(
                f"Field cl_paradigm should be either 'TIL' or 'CIL' but got {self.cfg.cl_paradigm}!"
            )

        if not self.cfg.get("num_tasks"):
            raise KeyError("Field num_tasks should be specified in experiment config!")

        if not self.cfg.cl_dataset.get("num_tasks"):
            raise KeyError("Field num_tasks should be specified in cl_dataset config!")

        if not self.cfg.num_tasks <= self.cfg.cl_dataset.num_tasks:
            raise ValueError(
                f"The experiment is set to run {self.cfg.num_tasks} tasks whereas only {self.cfg.cl_dataset.num_tasks} exist in the current cl_dataset setting!"
            )

        if self.cfg.get("test") is None:  # check presence, not truthiness: test=False is a valid value
            raise KeyError("Field test should be specified in experiment config!")

        if not self.cfg.get("output_dir_name"):
            raise KeyError(
                "Field output_dir_name should be specified in experiment config!"
            )
```
Check the sanity of the config dict `self.cfg`.

**Raises:**
- **KeyError**: when required fields in experiment config are missing, including `cl_paradigm`, `num_tasks`, `test`, `output_dir_name`.
- **ValueError**: when the value of `cl_paradigm` is not 'TIL' or 'CIL', or when the number of tasks is larger than the number of tasks in the CL dataset.
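A quick demonstration of the paradigm check (a complete config like the sketch above, but with an invalid `cl_paradigm`; the check runs inside `__init__()`):

```python
from omegaconf import OmegaConf

from clarena import CLExperiment

bad_cfg = OmegaConf.create(
    {
        "cl_paradigm": "DIL",  # not one of 'TIL' / 'CIL'
        "num_tasks": 3,
        "test": True,
        "output_dir_name": "outputs/demo",
        "cl_dataset": {"num_tasks": 3},
    }
)
try:
    CLExperiment(bad_cfg)  # sanity_check() is called during construction
except ValueError as err:
    print(err)  # Field cl_paradigm should be either 'TIL' or 'CIL' but got DIL!
```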
```python
    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        r"""Instantiate the CL dataset object from cl_dataset config.

        **Args:**
        - **cl_dataset_cfg** (`DictConfig`): the cl_dataset config dict.
        """
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            cl_dataset_cfg.get("_target_"),
        )
        self.cl_dataset: LightningDataModule = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            cl_dataset_cfg.get("_target_"),
        )
```
Instantiate the CL dataset object from cl_dataset config.
**Args:**
- **cl_dataset_cfg** (`DictConfig`): the cl_dataset config dict.
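Every component is built this way: a `_target_` field names the class and the remaining fields become constructor arguments. A sketch with a hypothetical dataset class (the class path and its arguments are illustrative, not a guaranteed CLArena API):

```python
import hydra
from omegaconf import OmegaConf

cl_dataset_cfg = OmegaConf.create(
    {
        "_target_": "clarena.cl_datasets.PermutedMNIST",  # hypothetical class path
        "num_tasks": 5,
    }
)
# equivalent to clarena.cl_datasets.PermutedMNIST(num_tasks=5)
cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
```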
```python
    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
        r"""Instantiate the CL backbone network object from backbone config.

        **Args:**
        - **backbone_cfg** (`DictConfig`): the backbone config dict.
        """
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.CLBackbone)...",
            backbone_cfg.get("_target_"),
        )
        self.backbone: nn.Module = hydra.utils.instantiate(backbone_cfg)
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.CLBackbone) instantiated!",
            backbone_cfg.get("_target_"),
        )
```
Instantiate the CL backbone network object from backbone config.
**Args:**
- **backbone_cfg** (`DictConfig`): the backbone config dict.
```python
    def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
        r"""Instantiate the CL output heads object according to field `cl_paradigm` and backbone `output_dim` in the config.

        **Args:**
        - **cl_paradigm** (`str`): the CL paradigm, either 'TIL' or 'CIL'.
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug(
            "CL paradigm is set as %s. Instantiating %s heads (torch.nn.Module)...",
            cl_paradigm,
            cl_paradigm,
        )
        self.heads: HeadsTIL | HeadsCIL = (
            HeadsTIL(input_dim=input_dim)
            if cl_paradigm == "TIL"
            else HeadsCIL(input_dim=input_dim)
        )
        pylogger.debug("%s heads (torch.nn.Module) instantiated!", cl_paradigm)
```
Instantiate the CL output heads object according to field `cl_paradigm` and backbone `output_dim` in the config.
**Args:**
- **cl_paradigm** (`str`): the CL paradigm, either 'TIL' or 'CIL'.
- **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
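A small sketch of the paradigm switch, assuming `HeadsTIL` and `HeadsCIL` are exported from `clarena.cl_heads` and a 512-dimensional backbone feature:

```python
from clarena.cl_heads import HeadsCIL, HeadsTIL  # assumed export location

cl_paradigm = "TIL"  # or "CIL"
backbone_output_dim = 512  # illustrative; must equal the backbone's output_dim

heads = (
    HeadsTIL(input_dim=backbone_output_dim)
    if cl_paradigm == "TIL"
    else HeadsCIL(input_dim=backbone_output_dim)
)
```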
```python
    def instantiate_cl_algorithm(self, cl_algorithm_cfg: DictConfig) -> None:
        r"""Instantiate the cl_algorithm object from cl_algorithm config.

        **Args:**
        - **cl_algorithm_cfg** (`DictConfig`): the cl_algorithm config dict.
        """
        pylogger.debug(
            "CL algorithm is set as <%s>. Instantiating <%s> (clarena.cl_algorithms.CLAlgorithm)...",
            cl_algorithm_cfg.get("_target_"),
            cl_algorithm_cfg.get("_target_"),
        )
        self.model: LightningModule = hydra.utils.instantiate(
            cl_algorithm_cfg,
            backbone=self.backbone,
            heads=self.heads,
        )
        pylogger.debug(
            "<%s> (clarena.cl_algorithms.CLAlgorithm) instantiated!",
            cl_algorithm_cfg.get("_target_"),
        )
```
Instantiate the cl_algorithm object from cl_algorithm config.
**Args:**
- **cl_algorithm_cfg** (`DictConfig`): the cl_algorithm config dict.
```python
    def instantiate_optimizer(
        self, optimizer_cfg: DictConfig | ListConfig, task_id: int
    ) -> None:
        r"""Instantiate the optimizer object for task `task_id` from optimizer config.

        **Args:**
        - **optimizer_cfg** (`DictConfig` or `ListConfig`): the optimizer config dict. If it's a `ListConfig`, it should contain an optimizer config for each task; otherwise, it's a uniform optimizer config for all tasks.
        - **task_id** (`int`): the target task ID.
        """
        if isinstance(optimizer_cfg, ListConfig):
            pylogger.debug("Distinct optimizer config is applied to each task.")
            optimizer_cfg = optimizer_cfg[task_id - 1]
        elif isinstance(optimizer_cfg, DictConfig):
            pylogger.debug("Uniform optimizer config is applied to all tasks.")

        # partially instantiate the optimizer, as the 'params' argument (which comes from the Lightning module) cannot be passed yet
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer) for task %d...",
            optimizer_cfg.get("_target_"),
            task_id,
        )
        self.optimizer: Optimizer = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) for task %d partially instantiated!",
            optimizer_cfg.get("_target_"),
            task_id,
        )
```
Instantiate the optimizer object for task `task_id` from optimizer config.
**Args:**
- **optimizer_cfg** (`DictConfig` or `ListConfig`): the optimizer config dict. If it's a `ListConfig`, it should contain an optimizer config for each task; otherwise, it's a uniform optimizer config for all tasks.
- **task_id** (`int`): the target task ID.
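The two accepted shapes of `optimizer_cfg`, sketched below. Hydra's `_partial_: true` is presumably how the configs here defer the `params` argument; `hydra.utils.instantiate` then returns a `functools.partial` that can be completed later by the Lightning module:

```python
import hydra
import torch.nn as nn
from omegaconf import OmegaConf

# Uniform config (DictConfig): one optimizer spec reused for every task.
uniform_cfg = OmegaConf.create(
    {"_target_": "torch.optim.SGD", "_partial_": True, "lr": 0.01}
)

# Per-task configs (ListConfig): entry task_id - 1 is picked for each task.
per_task_cfg = OmegaConf.create(
    [
        {"_target_": "torch.optim.SGD", "_partial_": True, "lr": 0.01},  # task 1
        {"_target_": "torch.optim.Adam", "_partial_": True, "lr": 0.001},  # task 2
    ]
)

partial_optimizer = hydra.utils.instantiate(uniform_cfg)  # functools.partial(SGD, lr=0.01)
model = nn.Linear(4, 2)  # stand-in for the CL model's parameters
optimizer = partial_optimizer(params=model.parameters())  # 'params' supplied later
```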
```python
    def instantiate_trainer(self, trainer_cfg: DictConfig, task_id: int) -> None:
        r"""Instantiate the trainer object for task `task_id` from trainer config.

        **Args:**
        - **trainer_cfg** (`DictConfig`): the trainer config dict. All tasks share the same trainer config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating trainer <%s> (lightning.Trainer) for task %d...",
            trainer_cfg.get("_target_"),
            task_id,
        )
        self.trainer: Trainer = hydra.utils.instantiate(
            trainer_cfg, callbacks=self.callbacks, logger=self.lightning_loggers
        )
        pylogger.debug(
            "Trainer <%s> (lightning.Trainer) for task %d instantiated!",
            trainer_cfg.get("_target_"),
            task_id,
        )
```
Instantiate the trainer object for task `task_id` from trainer config.
**Args:**
- **trainer_cfg** (`DictConfig`): the trainer config dict. All tasks share the same trainer config but different objects.
- **task_id** (`int`): the target task ID.
```python
    def instantiate_lightning_loggers(
        self, lightning_loggers_cfg: DictConfig, task_id: int
    ) -> None:
        r"""Instantiate the list of Lightning logger objects for task `task_id` from lightning_loggers config.

        **Args:**
        - **lightning_loggers_cfg** (`DictConfig`): the lightning_loggers config dict. All tasks share the same lightning_loggers config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating Lightning loggers (lightning.Logger) for task %d...", task_id
        )
        self.lightning_loggers: list[Logger] = [
            hydra.utils.instantiate(
                lightning_logger, version=f"task_{task_id}"
            )  # name each task's log directory with a "task_" prefix in the Lightning logs
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug(
            "Lightning loggers (lightning.Logger) for task %d instantiated!", task_id
        )
```
Instantiate the list of Lightning logger objects for task `task_id` from lightning_loggers config.
**Args:**
- **lightning_loggers_cfg** (`DictConfig`): the lightning_loggers config dict. All tasks share the same lightning_loggers config but different objects.
- **task_id** (`int`): the target task ID.
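A sketch of a loggers config with a single CSV logger (`CSVLogger` is a standard Lightning logger; the config key and `save_dir` are illustrative), showing how the `version` override redirects each task's logs into a `task_<id>` subdirectory:

```python
import hydra
from omegaconf import OmegaConf

lightning_loggers_cfg = OmegaConf.create(
    {
        "csv": {
            "_target_": "lightning.pytorch.loggers.CSVLogger",
            "save_dir": "outputs/demo",
        }
    }
)

task_id = 1
loggers = [
    hydra.utils.instantiate(logger_cfg, version=f"task_{task_id}")
    for logger_cfg in lightning_loggers_cfg.values()
]
# CSV logs land under outputs/demo/lightning_logs/task_1/
```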
```python
    def instantiate_callbacks(self, callbacks_cfg: DictConfig, task_id: int) -> None:
        r"""Instantiate the list of callback objects for task `task_id` from callbacks config.

        **Args:**
        - **callbacks_cfg** (`DictConfig`): the callbacks config dict. All tasks share the same callbacks config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating callbacks (lightning.Callback) for task %d...", task_id
        )
        self.callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg.values()
        ]
        pylogger.debug(
            "Callbacks (lightning.Callback) for task %d instantiated!", task_id
        )
```
Instantiate the list of callback objects for task `task_id` from callbacks config.
**Args:**
- **callbacks_cfg** (`DictConfig`): the callbacks config dict. All tasks share the same callbacks config but different objects.
- **task_id** (`int`): the target task ID.
```python
    def set_global_seed(self) -> None:
        r"""Set the global seed for the entire experiment."""
        L.seed_everything(self.global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", self.global_seed)
```
Set the global seed for the entire experiment.
```python
    def setup_task_id(self, task_id: int) -> None:
        r"""Set up the current task_id at the beginning of the continual learning process of a new task.

        **Args:**
        - **task_id** (`int`): current task_id.
        """
        self.task_id = task_id
```
Set up the current task_id at the beginning of the continual learning process of a new task.
**Args:**
- **task_id** (`int`): current task_id.
```python
    def instantiate_global(self) -> None:
        r"""Instantiate global components for the entire CL experiment from `self.cfg`."""

        self.instantiate_cl_dataset(self.cfg.cl_dataset)
        self.instantiate_backbone(self.cfg.backbone)
        self.instantiate_heads(self.cfg.cl_paradigm, self.cfg.backbone.output_dim)
        self.instantiate_cl_algorithm(
            self.cfg.cl_algorithm
        )  # cl_algorithm should be instantiated after backbone and heads
```
Instantiate global components for the entire CL experiment from `self.cfg`.
```python
    def setup_global(self) -> None:
        r"""Set the global seed and let the CL dataset know the CL paradigm so it can define its CL class map."""
        self.set_global_seed()
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
```
Set the global seed and let the CL dataset know the CL paradigm so it can define its CL class map.
```python
    def instantiate_task_specific(self) -> None:
        r"""Instantiate task-specific components for the current task `self.task_id` from `self.cfg`."""

        self.instantiate_optimizer(self.cfg.optimizer, self.task_id)
        self.instantiate_callbacks(self.cfg.callbacks, self.task_id)
        self.instantiate_lightning_loggers(self.cfg.lightning_loggers, self.task_id)
        self.instantiate_trainer(
            self.cfg.trainer, self.task_id
        )  # trainer should be instantiated after loggers and callbacks
```
Instantiate task-specific components for the current task `self.task_id` from `self.cfg`.
```python
    def setup_task_specific(self) -> None:
        r"""Set up task-specific components to get ready for the current task `self.task_id`."""

        self.cl_dataset.setup_task_id(self.task_id)
        self.backbone.setup_task_id(self.task_id)
        self.model.setup_task_id(
            self.task_id,
            len(self.cl_dataset.cl_class_map(self.task_id)),
            self.optimizer,
        )

        pylogger.debug(
            "Datamodule, model and loggers are all set up and ready for task %d!",
            self.task_id,
        )
```
Set up task-specific components to get ready for the current task `self.task_id`.
```python
    def run_task(self) -> None:
        r"""Fit the model on the current task `self.task_id`. Also test the model if `self.test` is set to True."""

        self.trainer.fit(
            model=self.model,
            datamodule=self.cl_dataset,
        )

        if self.test:
            # test after training and validation
            self.trainer.test(
                model=self.model,
                datamodule=self.cl_dataset,
            )
```
Fit the model on the current task `self.task_id`. Also test the model if `self.test` is set to True.
```python
    def run(self) -> None:
        r"""The main method to run the continual learning experiment."""

        self.instantiate_global()
        self.setup_global()

        # task loop
        for task_id in range(1, self.num_tasks + 1):  # task ID counts from 1

            self.setup_task_id(task_id)
            self.instantiate_task_specific()
            self.setup_task_specific()

            self.run_task()
```
The main method to run the continual learning experiment.
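Equivalently, the loop can be driven step by step, which is handy when debugging a single task (using a full config as in the sketches above):

```python
expr = CLExperiment(cfg)
expr.instantiate_global()  # dataset, backbone, heads, algorithm
expr.setup_global()  # global seed + paradigm handshake with the dataset

expr.setup_task_id(1)  # task IDs count from 1
expr.instantiate_task_specific()  # optimizer, callbacks, loggers, trainer
expr.setup_task_specific()
expr.run_task()  # fit (and test, if cfg.test) on task 1
```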