clarena
Welcome to CLArena
CLArena (Continual Learning Arena) is an open-source Python package for Continual Learning (CL) research, solely developed by myself. In this package, I provide an integrated environment and various APIs to conduct CL experiments for research purposes, as well as implemented CL algorithms and datasets that you can give a spin immediately.
Please note that this is an API documentation providing detailed information about the available classes, functions, and modules in CLArena. Please refer to the main documentation and my beginners' guide to continual learning for more intuitive tutorials, examples, and guides on how to use CLArena.
- Main documentation: https://pengxiang-wang.com/projects/continual-learning-arena/
- A beginners' guide to continual learning: https://pengxiang-wang.com/posts/continual-learning-beginners-guide
We provide various components of a continual learning system in the submodules:
- `clarena.cl_datasets`: Continual learning datasets.
- `clarena.backbones`: Neural network architectures used as backbones for CL algorithms.
- `clarena.cl_heads`: Multi-head classifiers for continual learning outputs. Task-Incremental Learning (TIL) heads and Class-Incremental Learning (CIL) heads are included.
- `clarena.cl_algorithms`: Implementations of various continual learning algorithms.
- `clarena.optimizers`: Custom optimizers tailored for continual learning.
- `clarena.callbacks`: Extra functionality or actions added to the continual learning process.
As well as the base class in the outermost directory of the package:
- Class `CLExperiment`: The base class for continual learning experiments.
1""" 2 3# Welcome to CLArena 4 5**CLArena (Continual Learning Arena)** is a open-source Python package for Continual Learning (CL) research, solely developed by myself. In this package, I provide a integrated environment and various APIs to conduct CL experiments for research purposes, as well as implemented CL algorithms and datasets that you can give it a spin immediately. 6 7Please note that this is an API documantation providing detailed information about the available classes, functions, and modules in CLArena. Please refer to the main documentation and my beginners' guide to continual learning for more intuitive tutorials, examples, and guides on how to use CLArena. 8 9- **Main documentation:** [https://pengxiang-wang.com/projects/continual-learning-arena/](https://pengxiang-wang.com/projects/continual-learning-arena/) 10- **A beginners' guide to continual learning:** [https://pengxiang-wang.com/posts/continual-learning-beginners-guide](https://pengxiang-wang.com/posts/continual-learning-beginners-guide) 11 12We provide various components of continual learning system in the submodules: 13 14- `clarena.cl_datasets`: Continual learning datasets. 15- `clarena.backbones`: Neural network architectures used as backbones for CL algorithms. 16- `clarena.cl_heads`: Multi-head classifiers for continual learning outputs. Task-Incremental Learning (TIL) head and Class-Incremental Learning (CIL) head are included. 17- `clarena.cl_algorithms`: Implementations of various continual learning algorithms. 18- `clarena.optimizers`: Custom optimizers tailored for continual learning. 19- `clarena.callbacks`: Extra functionality or action added in the continual learning process. 20 21As well as the base class in the outmost directory of the package: 22 23- Class `CLExperiment`: The base class for continual learning experiments. 24 25""" 26 27from .base import CLExperiment 28 29__all__ = [ 30 "CLExperiment", 31 "cl_datasets", 32 "backbones", 33 "cl_heads", 34 "cl_algorithms", 35 "optimizers", 36 "callbacks", 37]
Class `CLExperiment`

```python
class CLExperiment:
    """The base class for continual learning experiments."""
```
The base class for continual learning experiments.
```python
    def __init__(self, cfg: DictConfig) -> None:
        """Initializes the CL experiment object with a complete configuration.

        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the CL experiment.
        """
        self.cfg: DictConfig = cfg
        """Store the complete config dict for any future reference."""
        self.sanity_check()

        self.cl_paradigm: str = cfg.cl_paradigm
        """Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Parsed from config and used to instantiate the correct heads object and set up the CL dataset."""
        self.num_tasks: int = cfg.num_tasks
        """Store the number of tasks to be conducted in this experiment. Parsed from config and used in the task loop."""
        self.test: bool = cfg.test
        """Store whether to test the model after training and validation. Parsed from config and used in the task loop."""
        self.output_dir_name: str = cfg.output_dir_name
        """Store the name of the output directory for logs and checkpoints. Parsed from config and helps any output operation locate the correct directory."""

        self.task_id: int
        """Task ID counter indicating which task is being processed. Self-updated during the task loop."""

        self.cl_dataset: CLDataset
        """CL dataset object. Instantiated in `instantiate_cl_dataset()`."""
        self.backbone: CLBackbone
        """Backbone network object. Instantiated in `instantiate_backbone()`."""
        self.heads: HeadsTIL | HeadsCIL
        """CL output heads object. Instantiated in `instantiate_heads()`."""
        self.model: CLAlgorithm
        """CL model object. Instantiated in `instantiate_cl_algorithm()`."""

        self.optimizer: Optimizer
        """Optimizer object for the current task `self.task_id`. Instantiated in `instantiate_optimizer()`."""
        self.trainer: Trainer
        """Trainer object for the current task `self.task_id`. Instantiated in `instantiate_trainer()`."""
        self.lightning_loggers: list[Logger]
        """The list of initialised Lightning logger objects for the current task `self.task_id`. Instantiated in `instantiate_lightning_loggers()`."""
        self.callbacks: list[Callback]
        """The list of initialised callback objects for the current task `self.task_id`. Instantiated in `instantiate_callbacks()`."""
```
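A minimal construction sketch with illustrative field values; a full experiment config also needs the `cl_dataset`, `backbone`, `cl_algorithm`, `optimizer`, `trainer`, `lightning_loggers` and `callbacks` sections used by the methods below:

```python
from omegaconf import OmegaConf

from clarena import CLExperiment

# Only the fields read by __init__() and checked by sanity_check() are shown here.
cfg = OmegaConf.create(
    {
        "cl_paradigm": "TIL",  # or "CIL"
        "num_tasks": 2,
        "test": True,
        "output_dir_name": "outputs/example",
        "cl_dataset": {"num_tasks": 2},  # compared against num_tasks above
    }
)

experiment = CLExperiment(cfg)
assert experiment.cl_paradigm == "TIL" and experiment.num_tasks == 2
```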
```python
    def sanity_check(self) -> None:
        """Check the sanity of the config dict `self.cfg`.

        **Raises:**
        - **KeyError**: when required fields in experiment config are missing, including `cl_paradigm`, `num_tasks`, `test`, `output_dir_name`.
        - **ValueError**: when the value of `cl_paradigm` is not 'TIL' or 'CIL', or when the number of tasks is larger than the number of tasks in the CL dataset.
        """
        if not self.cfg.get("cl_paradigm"):
            raise KeyError(
                "Field cl_paradigm should be specified in experiment config!"
            )

        if self.cfg.cl_paradigm not in ["TIL", "CIL"]:
            raise ValueError(
                f"Field cl_paradigm should be either 'TIL' or 'CIL' but got {self.cfg.cl_paradigm}!"
            )

        if not self.cfg.get("num_tasks"):
            raise KeyError("Field num_tasks should be specified in experiment config!")

        if not self.cfg.cl_dataset.get("num_tasks"):
            raise KeyError("Field num_tasks should be specified in cl_dataset config!")

        if not self.cfg.num_tasks <= self.cfg.cl_dataset.num_tasks:
            raise ValueError(
                f"The experiment is set to run {self.cfg.num_tasks} tasks whereas only {self.cfg.cl_dataset.num_tasks} exist in the current cl_dataset setting!"
            )

        # explicit None check: test is a bool and may legitimately be False
        if self.cfg.get("test") is None:
            raise KeyError("Field test should be specified in experiment config!")

        if not self.cfg.get("output_dir_name"):
            raise KeyError(
                "Field output_dir_name should be specified in experiment config!"
            )
```
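For instance, under the assumptions of the sketch above, a config requesting more tasks than the dataset provides fails this check:

```python
from omegaconf import OmegaConf

from clarena import CLExperiment

bad_cfg = OmegaConf.create(
    {
        "cl_paradigm": "TIL",
        "num_tasks": 5,
        "test": True,
        "output_dir_name": "outputs/example",
        "cl_dataset": {"num_tasks": 3},  # fewer tasks than the experiment requests
    }
)

try:
    CLExperiment(bad_cfg)  # sanity_check() runs inside __init__()
except ValueError as err:
    print(err)  # "The experiment is set to run 5 tasks whereas only 3 exist ..."
```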
```python
    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        """Instantiate the CL dataset object from cl_dataset config.

        **Args:**
        - **cl_dataset_cfg** (`DictConfig`): the cl_dataset config dict.
        """
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            cl_dataset_cfg.get("_target_"),
        )
        self.cl_dataset: LightningDataModule = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            cl_dataset_cfg.get("_target_"),
        )
```
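Continuing the earlier sketch: a cl_dataset config is a Hydra-style dict whose `_target_` points at a class in `clarena.cl_datasets`. The class name and arguments below are hypothetical:

```python
from omegaconf import OmegaConf

cl_dataset_cfg = OmegaConf.create(
    {
        "_target_": "clarena.cl_datasets.PermutedMNIST",  # hypothetical class name
        "num_tasks": 10,
        "batch_size": 64,
    }
)

experiment.instantiate_cl_dataset(cl_dataset_cfg)  # sets experiment.cl_dataset
```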
```python
    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
        """Instantiate the CL backbone network object from backbone config.

        **Args:**
        - **backbone_cfg** (`DictConfig`): the backbone config dict.
        """
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.CLBackbone)...",
            backbone_cfg.get("_target_"),
        )
        self.backbone: nn.Module = hydra.utils.instantiate(backbone_cfg)
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.CLBackbone) instantiated!",
            backbone_cfg.get("_target_"),
        )
```
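The backbone config works the same way; note that `output_dim` is read again by `instantiate_heads()` below (class name hypothetical):

```python
from omegaconf import OmegaConf

backbone_cfg = OmegaConf.create(
    {
        "_target_": "clarena.backbones.MLP",  # hypothetical class name
        "output_dim": 64,  # must equal the heads' input_dim (see instantiate_heads)
    }
)

experiment.instantiate_backbone(backbone_cfg)
```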
```python
    def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
        """Instantiate the CL output heads object according to field `cl_paradigm` and backbone `output_dim` in the config.

        **Args:**
        - **cl_paradigm** (`str`): the CL paradigm, either 'TIL' or 'CIL'.
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug(
            "CL paradigm is set as %s. Instantiating %s heads (torch.nn.Module)...",
            cl_paradigm,
            cl_paradigm,
        )
        self.heads: HeadsTIL | HeadsCIL = (
            HeadsTIL(input_dim=input_dim)
            if cl_paradigm == "TIL"
            else HeadsCIL(input_dim=input_dim)
        )
        pylogger.debug("%s heads (torch.nn.Module) instantiated!", cl_paradigm)
```
```python
    def instantiate_cl_algorithm(self, cl_algorithm_cfg: DictConfig) -> None:
        """Instantiate the cl_algorithm object from cl_algorithm config.

        **Args:**
        - **cl_algorithm_cfg** (`DictConfig`): the cl_algorithm config dict.
        """
        pylogger.debug(
            "CL algorithm is set as <%s>. Instantiating <%s> (clarena.cl_algorithms.CLAlgorithm)...",
            cl_algorithm_cfg.get("_target_"),
            cl_algorithm_cfg.get("_target_"),
        )
        self.model: LightningModule = hydra.utils.instantiate(
            cl_algorithm_cfg,
            backbone=self.backbone,
            heads=self.heads,
        )
        pylogger.debug(
            "<%s> (clarena.cl_algorithms.CLAlgorithm) instantiated!",
            cl_algorithm_cfg.get("_target_"),
        )
```
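The algorithm config follows the same pattern; `backbone` and `heads` are injected as extra keyword arguments on top of the config, which is why this must run after `instantiate_backbone()` and `instantiate_heads()`. The class name below is hypothetical:

```python
from omegaconf import OmegaConf

cl_algorithm_cfg = OmegaConf.create(
    {"_target_": "clarena.cl_algorithms.Finetuning"}  # hypothetical class name
)

# hydra.utils.instantiate() merges backbone=... and heads=... into the call
experiment.instantiate_cl_algorithm(cl_algorithm_cfg)
```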
```python
    def instantiate_optimizer(
        self, optimizer_cfg: DictConfig | ListConfig, task_id: int
    ) -> None:
        """Instantiate the optimizer object for task `task_id` from optimizer config.

        **Args:**
        - **optimizer_cfg** (`DictConfig` or `ListConfig`): the optimizer config dict. If it's a `ListConfig`, it should contain an optimizer config for each task; otherwise, it's a uniform optimizer config for all tasks.
        - **task_id** (`int`): the target task ID.
        """
        if isinstance(optimizer_cfg, ListConfig):
            pylogger.debug("Distinct optimizer config is applied to each task.")
            optimizer_cfg = optimizer_cfg[task_id - 1]
        elif isinstance(optimizer_cfg, DictConfig):
            pylogger.debug("Uniform optimizer config is applied to all tasks.")

        # Partially instantiate the optimizer: the 'params' argument comes from the
        # Lightning module and cannot be passed at this point.
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer) for task %d...",
            optimizer_cfg.get("_target_"),
            task_id,
        )
        self.optimizer: Optimizer = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated for task %d!",
            optimizer_cfg.get("_target_"),
            task_id,
        )
```
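A sketch of the two accepted shapes, assuming (as the "partially instantiate" comment suggests) that optimizer configs set Hydra's `_partial_: true`, so that `hydra.utils.instantiate` returns a `functools.partial` and `params` can be bound later:

```python
from omegaconf import OmegaConf

# Uniform config (DictConfig): the same optimizer settings for every task.
uniform_cfg = OmegaConf.create(
    {"_target_": "torch.optim.SGD", "_partial_": True, "lr": 0.01}
)

# Per-task configs (ListConfig), indexed by task_id - 1.
per_task_cfg = OmegaConf.create(
    [
        {"_target_": "torch.optim.SGD", "_partial_": True, "lr": 0.01},
        {"_target_": "torch.optim.Adam", "_partial_": True, "lr": 1e-3},
    ]
)

experiment.instantiate_optimizer(per_task_cfg, task_id=2)  # picks the Adam entry
```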
```python
    def instantiate_trainer(self, trainer_cfg: DictConfig, task_id: int) -> None:
        """Instantiate the trainer object for task `task_id` from trainer config.

        **Args:**
        - **trainer_cfg** (`DictConfig`): the trainer config dict. All tasks share the same trainer config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating trainer <%s> (lightning.Trainer) for task %d...",
            trainer_cfg.get("_target_"),
            task_id,
        )
        self.trainer: Trainer = hydra.utils.instantiate(
            trainer_cfg, callbacks=self.callbacks, logger=self.lightning_loggers
        )
        pylogger.debug(
            "Trainer <%s> (lightning.Trainer) for task %d instantiated!",
            trainer_cfg.get("_target_"),
            task_id,
        )
```
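A minimal trainer config sketch targeting `lightning.Trainer` directly; the callbacks and loggers must already be instantiated, since they are passed through here:

```python
from omegaconf import OmegaConf

trainer_cfg = OmegaConf.create(
    {
        "_target_": "lightning.Trainer",
        "max_epochs": 1,
        "enable_progress_bar": False,
    }
)

experiment.instantiate_trainer(trainer_cfg, task_id=1)
```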
```python
    def instantiate_lightning_loggers(
        self, lightning_loggers_cfg: DictConfig, task_id: int
    ) -> None:
        """Instantiate the list of lightning loggers objects for task `task_id` from lightning_loggers config.

        **Args:**
        - **lightning_loggers_cfg** (`DictConfig`): the lightning_loggers config dict. All tasks share the same lightning_loggers config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating Lightning loggers (lightning.Logger) for task %d...", task_id
        )
        self.lightning_loggers: list[Logger] = [
            hydra.utils.instantiate(
                lightning_logger, version=f"task_{task_id}"
            )  # change the directory name to "task_" prefix in lightning logs
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug(
            "Lightning loggers (lightning.Logger) for task %d instantiated!", task_id
        )
```
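The lightning_loggers config is a dict of named logger configs; only its `.values()` are used, and each logger is created with `version="task_<id>"`. A sketch using a standard Lightning logger (the entry name "csv" is arbitrary):

```python
from omegaconf import OmegaConf

lightning_loggers_cfg = OmegaConf.create(
    {
        "csv": {  # the key is just a name; .values() iterates the configs
            "_target_": "lightning.pytorch.loggers.CSVLogger",
            "save_dir": "outputs/example/logs",
        }
    }
)

experiment.instantiate_lightning_loggers(lightning_loggers_cfg, task_id=1)
# each logger is created with version="task_1", so logs land under .../task_1/
```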
```python
    def instantiate_callbacks(self, callbacks_cfg: DictConfig, task_id: int) -> None:
        """Instantiate the list of callbacks objects for task `task_id` from callbacks config.

        **Args:**
        - **callbacks_cfg** (`DictConfig`): the callbacks config dict. All tasks share the same callbacks config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating callbacks (lightning.Callback) for task %d...", task_id
        )
        self.callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg.values()
        ]
        pylogger.debug(
            "Callbacks (lightning.Callback) for task %d instantiated!", task_id
        )
```
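Callbacks follow the same named-dict pattern; the checkpoint entry below is a standard Lightning callback, used purely as an illustration:

```python
from omegaconf import OmegaConf

callbacks_cfg = OmegaConf.create(
    {
        "checkpoint": {
            "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
            "dirpath": "outputs/example/checkpoints",
        }
    }
)

experiment.instantiate_callbacks(callbacks_cfg, task_id=1)
```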
```python
    def setup_task_id(self, task_id: int) -> None:
        """Set up the current task_id at the beginning of the continual learning process of a new task.

        **Args:**
        - **task_id** (`int`): current task_id.
        """
        self.task_id = task_id
```
```python
    def instantiate_global(self) -> None:
        """Instantiate global components for the entire CL experiment from `self.cfg`."""

        self.instantiate_cl_dataset(self.cfg.cl_dataset)
        self.instantiate_backbone(self.cfg.backbone)
        self.instantiate_heads(self.cfg.cl_paradigm, self.cfg.backbone.output_dim)
        self.instantiate_cl_algorithm(
            self.cfg.cl_algorithm
        )  # cl_algorithm should be instantiated after backbone and heads
```
```python
    def setup_global(self) -> None:
        """Let the CL dataset know the CL paradigm to define its CL class map."""
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
```
```python
    def instantiate_task_specific(self) -> None:
        """Instantiate task-specific components for the current task `self.task_id` from `self.cfg`."""

        self.instantiate_optimizer(self.cfg.optimizer, self.task_id)
        self.instantiate_callbacks(self.cfg.callbacks, self.task_id)
        self.instantiate_lightning_loggers(self.cfg.lightning_loggers, self.task_id)
        self.instantiate_trainer(
            self.cfg.trainer, self.task_id
        )  # trainer should be instantiated after loggers and callbacks
```
```python
    def setup_task_specific(self) -> None:
        """Set up task-specific components to get ready for the current task `self.task_id`."""

        self.cl_dataset.setup_task_id(self.task_id)
        self.backbone.setup_task_id(self.task_id)
        self.model.setup_task_id(
            self.task_id,
            len(self.cl_dataset.cl_class_map(self.task_id)),
            self.optimizer,
        )

        pylogger.debug(
            "Datamodule, model and loggers are all set up ready for task %d!",
            self.task_id,
        )
```
```python
    def fit_task(self) -> None:
        """Fit the model on the current task `self.task_id`. Also test the model if `self.test` is set to True."""

        self.trainer.fit(
            model=self.model,
            datamodule=self.cl_dataset,
        )

        if self.test:
            # test after train and validate
            self.trainer.test(
                model=self.model,
                datamodule=self.cl_dataset,
            )
```
```python
    def fit(self) -> None:
        """The main method to run the continual learning experiment."""

        self.instantiate_global()
        self.setup_global()

        # task loop
        for task_id in range(1, self.num_tasks + 1):  # task ID counts from 1

            self.setup_task_id(task_id)
            self.instantiate_task_specific()
            self.setup_task_specific()

            self.fit_task()
```
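Putting it together, running an experiment is just construction plus `fit()`, assuming `cfg` carries all the sections referenced by the methods above:

```python
experiment = CLExperiment(cfg)  # cfg: cl_paradigm, num_tasks, test, output_dir_name,
                                # cl_dataset, backbone, cl_algorithm, optimizer,
                                # trainer, lightning_loggers, callbacks
experiment.fit()  # instantiates global components, then loops over tasks 1..num_tasks
```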