clarena.pipelines.cl_main_expr

The submodule in `pipelines` for the continual learning main experiment.

  1r"""
  2The submodule in `pipelines` for continual learning main experiment.
  3"""
  4
  5__all__ = ["CLMainExperiment"]
  6
  7import logging
  8from typing import Any
  9
 10import hydra
 11import lightning as L
 12from lightning import Callback, Trainer
 13from lightning.pytorch.loggers import Logger
 14from omegaconf import DictConfig, ListConfig
 15from torch.optim import Optimizer
 16from torch.optim.lr_scheduler import LRScheduler
 17
 18from clarena.backbones import CLBackbone
 19from clarena.cl_algorithms import CLAlgorithm
 20from clarena.cl_datasets import CLDataset
 21from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL
 22from clarena.utils.cfg import select_hyperparameters_from_config
 23
 24# always get logger for built-in logging in each module
 25pylogger = logging.getLogger(__name__)
 26
 27
 28class CLMainExperiment:
 29    r"""The base class for continual learning main experiment."""
 30
 31    def __init__(self, cfg: DictConfig) -> None:
 32        r"""
 33        **Args:**
 34        - **cfg** (`DictConfig`): the complete config dict for the continual learning main experiment.
 35        """
 36        self.cfg: DictConfig = cfg
 37        r"""The complete config dict."""
 38
 39        CLMainExperiment.sanity_check(self)
 40
 41        # required config fields
 42        self.cl_paradigm: str = cfg.cl_paradigm
 43        r"""The continual learning paradigm."""
 44        self.train_tasks: list[int] = (
 45            cfg.train_tasks
 46            if isinstance(cfg.train_tasks, ListConfig)
 47            else list(range(1, cfg.train_tasks + 1))
 48        )
 49        r"""The list of task IDs to train."""
 50        self.eval_after_tasks: list[int] = (
 51            cfg.eval_after_tasks
 52            if isinstance(cfg.eval_after_tasks, ListConfig)
 53            else list(range(1, cfg.eval_after_tasks + 1))
 54        )
 55        r"""If task ID $t$ is in this list, run the evaluation process for all seen tasks after training task $t$."""
 56        self.global_seed: int = cfg.global_seed
 57        r"""The global seed for the entire experiment."""
 58        self.output_dir: str = cfg.output_dir
 59        r"""The folder for storing the experiment results."""
 60
 61        # components
 62
 63        # global components
 64        self.cl_dataset: CLDataset
 65        r"""CL dataset object."""
 66        self.backbone: CLBackbone
 67        r"""Backbone network object."""
 68        self.heads: HeadsTIL | HeadsCIL
 69        r"""CL output heads object."""
 70        self.model: CLAlgorithm
 71        r"""CL model object."""
 72        self.lightning_loggers: list[Logger]
 73        r"""Lightning logger objects."""
 74        self.callbacks: list[Callback]
 75        r"""Callback objects."""
 76
 77        # task-specific components
 78        self.optimizer_t: Optimizer
 79        r"""Optimizer object for the current task `self.task_id`."""
 80        self.lr_scheduler_t: LRScheduler | None = None
 81        r"""Learning rate scheduler object for the current task `self.task_id`."""
 82        self.trainer_t: Trainer
 83        r"""Trainer object for the current task `self.task_id`."""
 84
 85        # task ID control
 86        self.task_id: int
 87        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
 88        self.processed_task_ids: list[int] = []
 89        r"""Task IDs that have been processed."""
 90
 91    def sanity_check(self) -> None:
 92        r"""Sanity check for config."""
 93
 94        # check required config fields
 95        required_config_fields = [
 96            "pipeline",
 97            "expr_name",
 98            "cl_paradigm",
 99            "train_tasks",
100            "eval_after_tasks",
101            "global_seed",
102            "cl_dataset",
103            "cl_algorithm",
104            "backbone",
105            "optimizer",
106            "trainer",
107            "metrics",
108            "lightning_loggers",
109            "callbacks",
110            "output_dir",
111            # "hydra" is excluded as it doesn't appear
112            "misc",
113        ]
114        for field in required_config_fields:
115            if not self.cfg.get(field):
116                raise KeyError(
117                    f"Field `{field}` is required in the experiment index config."
118                )
119
120        # check cl_paradigm
121        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
122            raise ValueError(
123                f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
124            )
125
126        # get dataset number of tasks
127        if self.cfg.cl_dataset.get("num_tasks"):
128            num_tasks = self.cfg.cl_dataset.get("num_tasks")
129        elif self.cfg.cl_dataset.get("class_split"):
130            num_tasks = len(self.cfg.cl_dataset.class_split)
131        elif self.cfg.cl_dataset.get("datasets"):
132            num_tasks = len(self.cfg.cl_dataset.datasets)
133        else:
134            raise KeyError(
135                "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
136            )
137
138        # check train_tasks
139        train_tasks = self.cfg.train_tasks
140        if isinstance(train_tasks, ListConfig):
141            if len(train_tasks) < 1:
142                raise ValueError("`train_tasks` config must contain at least one task.")
143            if any(t < 1 or t > num_tasks for t in train_tasks):
144                raise ValueError(
145                    f"All task IDs in `train_tasks` config must be between 1 and {num_tasks}."
146                )
147        elif isinstance(train_tasks, int):
148            if train_tasks < 0 or train_tasks > num_tasks:
149                raise ValueError(
150                    f"`train_tasks` config as integer must be between 0 and {num_tasks}."
151                )
152        else:
153            raise TypeError(
154                "`train_tasks` config must be either a list of integers or an integer."
155            )
156
157        # check eval_after_tasks
158        eval_after_tasks = self.cfg.eval_after_tasks
159        if isinstance(eval_after_tasks, ListConfig):
160            if len(eval_after_tasks) < 1:
161                raise ValueError(
162                    "`eval_after_tasks` config must contain at least one task."
163                )
164            if any(t < 1 or t > num_tasks for t in eval_after_tasks):
165                raise ValueError(
166                    f"All task IDs in `eval_after_tasks` config must be between 1 and {num_tasks}."
167                )
168        elif isinstance(eval_after_tasks, int):
169            if eval_after_tasks < 0 or eval_after_tasks > num_tasks:
170                raise ValueError(
171                    f"`eval_after_tasks` config as integer must be between 0 and {num_tasks}."
172                )
173        else:
174            raise TypeError(
175                "`eval_after_tasks` config must be either a list of integers or an integer."
176            )
177
178        # check that eval_after_tasks is a subset of train_tasks
179        if isinstance(train_tasks, list) and isinstance(eval_after_tasks, list):
180            if not set(eval_after_tasks).issubset(set(train_tasks)):
181                raise ValueError(
182                    "`eval_after_tasks` config must be a subset of `train_tasks` config."
183                )
184
185    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
186        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
187        pylogger.debug(
188            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
189            cl_dataset_cfg.get("_target_"),
190        )
191        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
192        pylogger.debug(
193            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
194            cl_dataset_cfg.get("_target_"),
195        )
196
197    def instantiate_backbone(
198        self, backbone_cfg: DictConfig, disable_unlearning: bool
199    ) -> None:
200        r"""Instantiate the CL backbone network object from `backbone_cfg`."""
201        pylogger.debug(
202            "Instantiating backbone network <%s> (clarena.backbones.CLBackbone)...",
203            backbone_cfg.get("_target_"),
204        )
205        self.backbone = hydra.utils.instantiate(
206            backbone_cfg, disable_unlearning=disable_unlearning
207        )
208        pylogger.debug(
209            "Backbone network <%s> (clarena.backbones.CLBackbone) instantiated!",
210            backbone_cfg.get("_target_"),
211        )
212
213    def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
214        r"""Instantiate the CL output heads object.
215
216        **Args:**
217        - **cl_paradigm** (`str`): the CL paradigm, either 'TIL', 'CIL' or 'DIL'. 'TIL' uses `HeadsTIL`, 'CIL' uses `HeadsCIL`, and 'DIL' uses `HeadDIL`.
218        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
219        """
220        pylogger.debug(
221            "CL paradigm is set as %s. Instantiating %s heads...",
222            cl_paradigm,
223            cl_paradigm,
224        )
225        if cl_paradigm == "TIL":
226            self.heads = HeadsTIL(input_dim=input_dim)
227        elif cl_paradigm == "CIL":
228            self.heads = HeadsCIL(input_dim=input_dim)
229        elif cl_paradigm == "DIL":
230            self.heads = HeadDIL(input_dim=input_dim)
231
232        pylogger.debug("%s heads instantiated!", cl_paradigm)
233
234    def instantiate_cl_algorithm(
235        self,
236        cl_algorithm_cfg: DictConfig,
237        backbone: CLBackbone,
238        heads: HeadsTIL | HeadsCIL | HeadDIL,
239        non_algorithmic_hparams: dict[str, Any],
240        disable_unlearning: bool,
241    ) -> None:
242        r"""Instantiate the cl_algorithm object from `cl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
243        pylogger.debug(
244            "CL algorithm is set as <%s>. Instantiating <%s> (clarena.cl_algorithms.CLAlgorithm)...",
245            cl_algorithm_cfg.get("_target_"),
246            cl_algorithm_cfg.get("_target_"),
247        )
248        self.model = hydra.utils.instantiate(
249            cl_algorithm_cfg,
250            backbone=backbone,
251            heads=heads,
252            non_algorithmic_hparams=non_algorithmic_hparams,
253            disable_unlearning=disable_unlearning,
254        )
255        pylogger.debug(
256            "<%s> (clarena.cl_algorithms.CLAlgorithm) instantiated!",
257            cl_algorithm_cfg.get("_target_"),
258        )
259
260    def instantiate_optimizer(
261        self,
262        optimizer_cfg: DictConfig,
263        task_id: int,
264    ) -> None:
265        r"""Instantiate the optimizer object for task `task_id` from `optimizer_cfg`."""
266
267        # distinguish whether the optimizer config is uniform or task-specific
268        if not optimizer_cfg.get("_target_"):
269            pylogger.debug("Distinct optimizer config is applied to each task.")
270            optimizer_cfg = optimizer_cfg[task_id]
271        else:
272            pylogger.debug("Uniform optimizer config is applied to all tasks.")
273
274        # partially instantiate optimizer as the 'params' argument from Lightning Modules cannot be passed for now
275        pylogger.debug(
276            "Partially instantiating optimizer <%s> (torch.optim.Optimizer) for task %d...",
277            optimizer_cfg.get("_target_"),
278            task_id,
279        )
280        self.optimizer_t = hydra.utils.instantiate(optimizer_cfg)
281        pylogger.debug(
282            "Optimizer <%s> (torch.optim.Optimizer) partially for task %d instantiated!",
283            optimizer_cfg.get("_target_"),
284            task_id,
285        )
286
287    def instantiate_lr_scheduler(
288        self,
289        lr_scheduler_cfg: DictConfig,
290        task_id: int,
291    ) -> None:
292        r"""Instantiate the learning rate scheduler object for task `task_id` from `lr_scheduler_cfg`."""
293
294        # distinguish whether the learning rate scheduler config is uniform or task-specific
295        if not lr_scheduler_cfg.get("_target_"):
296            pylogger.debug(
297                "Distinct learning rate scheduler config is applied to each task."
298            )
299            lr_scheduler_cfg = lr_scheduler_cfg[task_id]
300        else:
301            pylogger.debug(
302                "Uniform learning rate scheduler config is applied to all tasks."
303            )
304
305        # partially instantiate learning rate scheduler as the 'optimizer' argument from Lightning Modules cannot be passed for now
306        pylogger.debug(
307            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) for task %d...",
308            lr_scheduler_cfg.get("_target_"),
309            task_id,
310        )
311        self.lr_scheduler_t = hydra.utils.instantiate(lr_scheduler_cfg)
312        pylogger.debug(
313            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially for task %d instantiated!",
314            lr_scheduler_cfg.get("_target_"),
315            task_id,
316        )
317
318    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
319        r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`."""
320
321        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
322        self.lightning_loggers = [
323            hydra.utils.instantiate(lightning_logger)
324            for lightning_logger in lightning_loggers_cfg.values()
325        ]
326        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")
327
328    def instantiate_callbacks(
329        self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
330    ) -> None:
331        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
332        pylogger.debug(
333            "Instantiating callbacks (lightning.Callback)...",
334        )
335
336        # instantiate metric callbacks
337        metric_callbacks = [
338            hydra.utils.instantiate(callback) for callback in metrics_cfg
339        ]
340
341        # instantiate other callbacks
342        other_callbacks = [
343            hydra.utils.instantiate(callback) for callback in callbacks_cfg
344        ]
345
346        # add metric callbacks to the list of callbacks
347        self.callbacks = metric_callbacks + other_callbacks
348        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
349
350    def instantiate_trainer(
351        self,
352        trainer_cfg: DictConfig,
353        lightning_loggers: list[Logger],
354        callbacks: list[Callback],
355        task_id: int,
356    ) -> None:
357        r"""Instantiate the trainer object for task `task_id` from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
358
359        if not trainer_cfg.get("_target_"):
360            pylogger.debug("Distinct trainer config is applied to each task.")
361            trainer_cfg = trainer_cfg[task_id]
362        else:
363            pylogger.debug("Uniform trainer config is applied to all tasks.")
364
365        pylogger.debug(
366            "Instantiating trainer (lightning.Trainer) for task %d...",
367            task_id,
368        )
369        self.trainer_t = hydra.utils.instantiate(
370            trainer_cfg,
371            logger=lightning_loggers,
372            callbacks=callbacks,
373        )
374        pylogger.debug(
375            "Trainer (lightning.Trainer) for task %d instantiated!",
376            task_id,
377        )
378
379    def set_global_seed(self, global_seed: int) -> None:
380        r"""Set the `global_seed` for the entire experiment."""
381        L.seed_everything(self.global_seed, workers=True)
382        pylogger.debug("Global seed is set as %d.", global_seed)
383
384    def run(self) -> None:
385        r"""The main method to run the continual learning main experiment."""
386
387        self.set_global_seed(self.global_seed)
388
389        # global components
390        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
391        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
392        self.instantiate_backbone(
393            backbone_cfg=self.cfg.backbone, disable_unlearning=True
394        )
395        self.instantiate_heads(
396            cl_paradigm=self.cl_paradigm, input_dim=self.cfg.backbone.output_dim
397        )
398        self.instantiate_cl_algorithm(
399            cl_algorithm_cfg=self.cfg.cl_algorithm,
400            backbone=self.backbone,
401            heads=self.heads,
402            non_algorithmic_hparams=select_hyperparameters_from_config(
403                cfg=self.cfg, type=self.cfg.pipeline
404            ),
405            disable_unlearning=True,
406        )  # cl_algorithm should be instantiated after backbone and heads
407        self.instantiate_lightning_loggers(
408            lightning_loggers_cfg=self.cfg.lightning_loggers
409        )
410        self.instantiate_callbacks(
411            metrics_cfg=self.cfg.metrics,
412            callbacks_cfg=self.cfg.callbacks,
413        )
414
415        # task loop
416        for task_id in self.train_tasks:
417
418            self.task_id = task_id
419
420            # task-specific components
421            self.instantiate_optimizer(
422                optimizer_cfg=self.cfg.optimizer,
423                task_id=task_id,
424            )
425            if self.cfg.get("lr_scheduler"):
426                self.instantiate_lr_scheduler(
427                    lr_scheduler_cfg=self.cfg.lr_scheduler,
428                    task_id=task_id,
429                )
430            self.instantiate_trainer(
431                trainer_cfg=self.cfg.trainer,
432                lightning_loggers=self.lightning_loggers,
433                callbacks=self.callbacks,
434                task_id=task_id,
435            )  # trainer should be instantiated after lightning loggers and callbacks
436
437            # setup task ID for dataset and model
438            self.cl_dataset.setup_task_id(task_id=task_id)
439            self.model.setup_task_id(
440                task_id=task_id,
441                num_classes=len(self.cl_dataset.get_cl_class_map(self.task_id)),
442                optimizer=self.optimizer_t,
443                lr_scheduler=self.lr_scheduler_t,
444            )
445
446            # train and validate the model
447            self.trainer_t.fit(
448                model=self.model,
449                datamodule=self.cl_dataset,
450            )
451
452            # evaluation after training and validation
453            if task_id in self.eval_after_tasks:
454                self.trainer_t.test(
455                    model=self.model,
456                    datamodule=self.cl_dataset,
457                )
458
459            self.processed_task_ids.append(task_id)
class CLMainExperiment:
 29class CLMainExperiment:
 30    r"""The base class for continual learning main experiment."""
 31
 32    def __init__(self, cfg: DictConfig) -> None:
 33        r"""
 34        **Args:**
 35        - **cfg** (`DictConfig`): the complete config dict for the continual learning main experiment.
 36        """
 37        self.cfg: DictConfig = cfg
 38        r"""The complete config dict."""
 39
 40        CLMainExperiment.sanity_check(self)
 41
 42        # required config fields
 43        self.cl_paradigm: str = cfg.cl_paradigm
 44        r"""The continual learning paradigm."""
 45        self.train_tasks: list[int] = (
 46            cfg.train_tasks
 47            if isinstance(cfg.train_tasks, ListConfig)
 48            else list(range(1, cfg.train_tasks + 1))
 49        )
 50        r"""The list of task IDs to train."""
 51        self.eval_after_tasks: list[int] = (
 52            cfg.eval_after_tasks
 53            if isinstance(cfg.eval_after_tasks, ListConfig)
 54            else list(range(1, cfg.eval_after_tasks + 1))
 55        )
 56        r"""If task ID $t$ is in this list, run the evaluation process for all seen tasks after training task $t$."""
 57        self.global_seed: int = cfg.global_seed
 58        r"""The global seed for the entire experiment."""
 59        self.output_dir: str = cfg.output_dir
 60        r"""The folder for storing the experiment results."""
 61
 62        # components
 63
 64        # global components
 65        self.cl_dataset: CLDataset
 66        r"""CL dataset object."""
 67        self.backbone: CLBackbone
 68        r"""Backbone network object."""
 69        self.heads: HeadsTIL | HeadsCIL
 70        r"""CL output heads object."""
 71        self.model: CLAlgorithm
 72        r"""CL model object."""
 73        self.lightning_loggers: list[Logger]
 74        r"""Lightning logger objects."""
 75        self.callbacks: list[Callback]
 76        r"""Callback objects."""
 77
 78        # task-specific components
 79        self.optimizer_t: Optimizer
 80        r"""Optimizer object for the current task `self.task_id`."""
 81        self.lr_scheduler_t: LRScheduler | None = None
 82        r"""Learning rate scheduler object for the current task `self.task_id`."""
 83        self.trainer_t: Trainer
 84        r"""Trainer object for the current task `self.task_id`."""
 85
 86        # task ID control
 87        self.task_id: int
 88        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
 89        self.processed_task_ids: list[int] = []
 90        r"""Task IDs that have been processed."""
 91
 92    def sanity_check(self) -> None:
 93        r"""Sanity check for config."""
 94
 95        # check required config fields
 96        required_config_fields = [
 97            "pipeline",
 98            "expr_name",
 99            "cl_paradigm",
100            "train_tasks",
101            "eval_after_tasks",
102            "global_seed",
103            "cl_dataset",
104            "cl_algorithm",
105            "backbone",
106            "optimizer",
107            "trainer",
108            "metrics",
109            "lightning_loggers",
110            "callbacks",
111            "output_dir",
112            # "hydra" is excluded as it doesn't appear
113            "misc",
114        ]
115        for field in required_config_fields:
116            if not self.cfg.get(field):
117                raise KeyError(
118                    f"Field `{field}` is required in the experiment index config."
119                )
120
121        # check cl_paradigm
122        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
123            raise ValueError(
124                f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
125            )
126
127        # get dataset number of tasks
128        if self.cfg.cl_dataset.get("num_tasks"):
129            num_tasks = self.cfg.cl_dataset.get("num_tasks")
130        elif self.cfg.cl_dataset.get("class_split"):
131            num_tasks = len(self.cfg.cl_dataset.class_split)
132        elif self.cfg.cl_dataset.get("datasets"):
133            num_tasks = len(self.cfg.cl_dataset.datasets)
134        else:
135            raise KeyError(
136                "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
137            )
138
139        # check train_tasks
140        train_tasks = self.cfg.train_tasks
141        if isinstance(train_tasks, ListConfig):
142            if len(train_tasks) < 1:
143                raise ValueError("`train_tasks` config must contain at least one task.")
144            if any(t < 1 or t > num_tasks for t in train_tasks):
145                raise ValueError(
146                    f"All task IDs in `train_tasks` config must be between 1 and {num_tasks}."
147                )
148        elif isinstance(train_tasks, int):
149            if train_tasks < 0 or train_tasks > num_tasks:
150                raise ValueError(
151                    f"`train_tasks` config as integer must be between 0 and {num_tasks}."
152                )
153        else:
154            raise TypeError(
155                "`train_tasks` config must be either a list of integers or an integer."
156            )
157
158        # check eval_after_tasks
159        eval_after_tasks = self.cfg.eval_after_tasks
160        if isinstance(eval_after_tasks, ListConfig):
161            if len(eval_after_tasks) < 1:
162                raise ValueError(
163                    "`eval_after_tasks` config must contain at least one task."
164                )
165            if any(t < 1 or t > num_tasks for t in eval_after_tasks):
166                raise ValueError(
167                    f"All task IDs in `eval_after_tasks` config must be between 1 and {num_tasks}."
168                )
169        elif isinstance(eval_after_tasks, int):
170            if eval_after_tasks < 0 or eval_after_tasks > num_tasks:
171                raise ValueError(
172                    f"`eval_after_tasks` config as integer must be between 0 and {num_tasks}."
173                )
174        else:
175            raise TypeError(
176                "`eval_after_tasks` config must be either a list of integers or an integer."
177            )
178
179        # check that eval_after_tasks is a subset of train_tasks
180        if isinstance(train_tasks, list) and isinstance(eval_after_tasks, list):
181            if not set(eval_after_tasks).issubset(set(train_tasks)):
182                raise ValueError(
183                    "`eval_after_tasks` config must be a subset of `train_tasks` config."
184                )
185
186    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
187        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
188        pylogger.debug(
189            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
190            cl_dataset_cfg.get("_target_"),
191        )
192        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
193        pylogger.debug(
194            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
195            cl_dataset_cfg.get("_target_"),
196        )
197
198    def instantiate_backbone(
199        self, backbone_cfg: DictConfig, disable_unlearning: bool
200    ) -> None:
201        r"""Instantiate the CL backbone network object from `backbone_cfg`."""
202        pylogger.debug(
203            "Instantiating backbone network <%s> (clarena.backbones.CLBackbone)...",
204            backbone_cfg.get("_target_"),
205        )
206        self.backbone = hydra.utils.instantiate(
207            backbone_cfg, disable_unlearning=disable_unlearning
208        )
209        pylogger.debug(
210            "Backbone network <%s> (clarena.backbones.CLBackbone) instantiated!",
211            backbone_cfg.get("_target_"),
212        )
213
214    def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
215        r"""Instantiate the CL output heads object.
216
217        **Args:**
218        - **cl_paradigm** (`str`): the CL paradigm, either 'TIL', 'CIL' or 'DIL'. 'TIL' uses `HeadsTIL`, 'CIL' uses `HeadsCIL`, and 'DIL' uses `HeadDIL`.
219        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
220        """
221        pylogger.debug(
222            "CL paradigm is set as %s. Instantiating %s heads...",
223            cl_paradigm,
224            cl_paradigm,
225        )
226        if cl_paradigm == "TIL":
227            self.heads = HeadsTIL(input_dim=input_dim)
228        elif cl_paradigm == "CIL":
229            self.heads = HeadsCIL(input_dim=input_dim)
230        elif cl_paradigm == "DIL":
231            self.heads = HeadDIL(input_dim=input_dim)
232
233        pylogger.debug("%s heads instantiated!", cl_paradigm)
234
235    def instantiate_cl_algorithm(
236        self,
237        cl_algorithm_cfg: DictConfig,
238        backbone: CLBackbone,
239        heads: HeadsTIL | HeadsCIL | HeadDIL,
240        non_algorithmic_hparams: dict[str, Any],
241        disable_unlearning: bool,
242    ) -> None:
243        r"""Instantiate the cl_algorithm object from `cl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
244        pylogger.debug(
245            "CL algorithm is set as <%s>. Instantiating <%s> (clarena.cl_algorithms.CLAlgorithm)...",
246            cl_algorithm_cfg.get("_target_"),
247            cl_algorithm_cfg.get("_target_"),
248        )
249        self.model = hydra.utils.instantiate(
250            cl_algorithm_cfg,
251            backbone=backbone,
252            heads=heads,
253            non_algorithmic_hparams=non_algorithmic_hparams,
254            disable_unlearning=disable_unlearning,
255        )
256        pylogger.debug(
257            "<%s> (clarena.cl_algorithms.CLAlgorithm) instantiated!",
258            cl_algorithm_cfg.get("_target_"),
259        )
260
261    def instantiate_optimizer(
262        self,
263        optimizer_cfg: DictConfig,
264        task_id: int,
265    ) -> None:
266        r"""Instantiate the optimizer object for task `task_id` from `optimizer_cfg`."""
267
268        # distinguish whether the optimizer config is uniform or task-specific
269        if not optimizer_cfg.get("_target_"):
270            pylogger.debug("Distinct optimizer config is applied to each task.")
271            optimizer_cfg = optimizer_cfg[task_id]
272        else:
273            pylogger.debug("Uniform optimizer config is applied to all tasks.")
274
275        # partially instantiate optimizer as the 'params' argument from Lightning Modules cannot be passed for now
276        pylogger.debug(
277            "Partially instantiating optimizer <%s> (torch.optim.Optimizer) for task %d...",
278            optimizer_cfg.get("_target_"),
279            task_id,
280        )
281        self.optimizer_t = hydra.utils.instantiate(optimizer_cfg)
282        pylogger.debug(
283            "Optimizer <%s> (torch.optim.Optimizer) partially for task %d instantiated!",
284            optimizer_cfg.get("_target_"),
285            task_id,
286        )
287
def instantiate_lr_scheduler(
    self,
    lr_scheduler_cfg: DictConfig,
    task_id: int,
) -> None:
    r"""Build the learning rate scheduler for task `task_id` and store it in `self.lr_scheduler_t`.

    **Args:**
    - **lr_scheduler_cfg** (`DictConfig`): either one scheduler config with a top-level `_target_`, applied uniformly to all tasks, or a config indexed by task ID holding a distinct scheduler config per task.
    - **task_id** (`int`): the task whose learning rate scheduler is being built.

    The scheduler is only partially instantiated here, because its 'optimizer' argument comes from the Lightning module and cannot be passed at this point.
    """
    is_uniform = bool(lr_scheduler_cfg.get("_target_"))
    if is_uniform:
        pylogger.debug(
            "Uniform learning rate scheduler config is applied to all tasks."
        )
        task_cfg = lr_scheduler_cfg
    else:
        # no top-level `_target_`: the config is indexed by task ID
        pylogger.debug(
            "Distinct learning rate scheduler config is applied to each task."
        )
        task_cfg = lr_scheduler_cfg[task_id]

    target = task_cfg.get("_target_")
    pylogger.debug(
        "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) for task %d...",
        target,
        task_id,
    )
    self.lr_scheduler_t = hydra.utils.instantiate(task_cfg)
    pylogger.debug(
        "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially for task %d instantiated!",
        target,
        task_id,
    )
def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
    r"""Build every Lightning logger configured in `lightning_loggers_cfg`.

    **Args:**
    - **lightning_loggers_cfg** (`DictConfig`): a mapping whose values are individual logger configs (the keys are not used).

    The loggers are stored in `self.lightning_loggers`, in the iteration order of the mapping.
    """
    pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
    logger_cfgs = lightning_loggers_cfg.values()
    self.lightning_loggers = list(map(hydra.utils.instantiate, logger_cfgs))
    pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")
def instantiate_callbacks(
    self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
) -> None:
    r"""Instantiate all callbacks objects from `metrics_cfg` and `callbacks_cfg`.

    **Args:**
    - **metrics_cfg** (`ListConfig`): configs of the metric callbacks.
    - **callbacks_cfg** (`ListConfig`): configs of the callbacks other than the metric callbacks.

    The instantiated callbacks are stored in `self.callbacks`: the metric callbacks first, followed by the other callbacks, each group in config order.
    """
    pylogger.debug(
        "Instantiating callbacks (lightning.Callback)...",
    )

    # instantiate metric callbacks
    metric_callbacks = [
        hydra.utils.instantiate(callback) for callback in metrics_cfg
    ]

    # instantiate other callbacks
    other_callbacks = [
        hydra.utils.instantiate(callback) for callback in callbacks_cfg
    ]

    # combine both groups into one list, metric callbacks first
    self.callbacks = metric_callbacks + other_callbacks
    pylogger.debug("Callbacks (lightning.Callback) instantiated!")
def instantiate_trainer(
    self,
    trainer_cfg: DictConfig,
    lightning_loggers: list[Logger],
    callbacks: list[Callback],
    task_id: int,
) -> None:
    r"""Build the Lightning trainer for task `task_id` and store it in `self.trainer_t`.

    **Args:**
    - **trainer_cfg** (`DictConfig`): either one trainer config with a top-level `_target_`, used for every task, or a config indexed by task ID holding a distinct trainer config per task.
    - **lightning_loggers** (`list[Logger]`): passed to the trainer as its `logger` argument.
    - **callbacks** (`list[Callback]`): passed to the trainer as its `callbacks` argument.
    - **task_id** (`int`): the task whose trainer is being built.
    """
    if trainer_cfg.get("_target_"):
        pylogger.debug("Uniform trainer config is applied to all tasks.")
        task_trainer_cfg = trainer_cfg
    else:
        pylogger.debug("Distinct trainer config is applied to each task.")
        task_trainer_cfg = trainer_cfg[task_id]

    pylogger.debug("Instantiating trainer (lightning.Trainer) for task %d...", task_id)
    self.trainer_t = hydra.utils.instantiate(
        task_trainer_cfg, logger=lightning_loggers, callbacks=callbacks
    )
    pylogger.debug("Trainer (lightning.Trainer) for task %d instantiated!", task_id)
def set_global_seed(self, global_seed: int) -> None:
    r"""Set the `global_seed` for the entire experiment.

    **Args:**
    - **global_seed** (`int`): the seed passed to `lightning.seed_everything` (with `workers=True`, which also configures seeding of dataloader workers).
    """
    # seed with the argument (not `self.global_seed`) so that the seed applied is the one logged below
    L.seed_everything(global_seed, workers=True)
    pylogger.debug("Global seed is set as %d.", global_seed)
def run(self) -> None:
    r"""Run the continual learning main experiment end to end.

    The global components (CL dataset, backbone, heads, CL algorithm, Lightning loggers and callbacks) are built once. Then, for every task ID in `self.train_tasks`:
    1. a task-specific optimizer, optional learning rate scheduler and trainer are built;
    2. the model is fitted on that task;
    3. the model is tested if the task ID is in `self.eval_after_tasks`.

    Each finished task ID is appended to `self.processed_task_ids`.
    """
    cfg = self.cfg
    self.set_global_seed(self.global_seed)

    # ---- global components, built once for the whole experiment ----
    self.instantiate_cl_dataset(cl_dataset_cfg=cfg.cl_dataset)
    self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
    self.instantiate_backbone(backbone_cfg=cfg.backbone, disable_unlearning=True)
    self.instantiate_heads(
        cl_paradigm=self.cl_paradigm, input_dim=cfg.backbone.output_dim
    )
    # the CL algorithm takes the backbone and heads as arguments, so it is built after them
    hparams = select_hyperparameters_from_config(cfg=cfg, type=cfg.pipeline)
    self.instantiate_cl_algorithm(
        cl_algorithm_cfg=cfg.cl_algorithm,
        backbone=self.backbone,
        heads=self.heads,
        non_algorithmic_hparams=hparams,
        disable_unlearning=True,
    )
    self.instantiate_lightning_loggers(lightning_loggers_cfg=cfg.lightning_loggers)
    self.instantiate_callbacks(metrics_cfg=cfg.metrics, callbacks_cfg=cfg.callbacks)

    # ---- task loop ----
    for task_id in self.train_tasks:
        self.task_id = task_id

        # task-specific components
        self.instantiate_optimizer(optimizer_cfg=cfg.optimizer, task_id=task_id)
        if cfg.get("lr_scheduler"):
            self.instantiate_lr_scheduler(
                lr_scheduler_cfg=cfg.lr_scheduler, task_id=task_id
            )
        # the trainer receives the loggers and callbacks built above
        self.instantiate_trainer(
            trainer_cfg=cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
            task_id=task_id,
        )

        # point the dataset and the model at the current task
        self.cl_dataset.setup_task_id(task_id=task_id)
        class_map = self.cl_dataset.get_cl_class_map(self.task_id)
        self.model.setup_task_id(
            task_id=task_id,
            num_classes=len(class_map),
            optimizer=self.optimizer_t,
            lr_scheduler=self.lr_scheduler_t,
        )

        # train and validate, then optionally evaluate
        self.trainer_t.fit(model=self.model, datamodule=self.cl_dataset)
        if task_id in self.eval_after_tasks:
            self.trainer_t.test(model=self.model, datamodule=self.cl_dataset)

        self.processed_task_ids.append(task_id)

The base class for continual learning main experiment.

CLMainExperiment(cfg: omegaconf.dictconfig.DictConfig)
def __init__(self, cfg: DictConfig) -> None:
    r"""
    **Args:**
    - **cfg** (`DictConfig`): the complete config dict for the continual learning main experiment. It is validated by `sanity_check()` before any field is read.
    """
    self.cfg: DictConfig = cfg
    r"""The complete config dict."""

    CLMainExperiment.sanity_check(self)

    # required config fields
    self.cl_paradigm: str = cfg.cl_paradigm
    r"""The continual learning paradigm."""
    # an integer `n` is shorthand for tasks 1..n; an explicit list is copied into a plain
    # `list` so the attribute really is the `list[int]` its annotation promises
    self.train_tasks: list[int] = (
        list(cfg.train_tasks)
        if isinstance(cfg.train_tasks, ListConfig)
        else list(range(1, cfg.train_tasks + 1))
    )
    r"""The list of task IDs to train."""
    self.eval_after_tasks: list[int] = (
        list(cfg.eval_after_tasks)
        if isinstance(cfg.eval_after_tasks, ListConfig)
        else list(range(1, cfg.eval_after_tasks + 1))
    )
    r"""If task ID $t$ is in this list, run the evaluation process for all seen tasks after training task $t$."""
    self.global_seed: int = cfg.global_seed
    r"""The global seed for the entire experiment."""
    self.output_dir: str = cfg.output_dir
    r"""The folder for storing the experiment results."""

    # components

    # global components
    self.cl_dataset: CLDataset
    r"""CL dataset object."""
    self.backbone: CLBackbone
    r"""Backbone network object."""
    self.heads: HeadsTIL | HeadsCIL | HeadDIL
    r"""CL output heads object."""
    self.model: CLAlgorithm
    r"""CL model object."""
    self.lightning_loggers: list[Logger]
    r"""Lightning logger objects."""
    self.callbacks: list[Callback]
    r"""Callback objects."""

    # task-specific components
    self.optimizer_t: Optimizer
    r"""Optimizer object for the current task `self.task_id`."""
    self.lr_scheduler_t: LRScheduler | None = None
    r"""Learning rate scheduler object for the current task `self.task_id`."""
    self.trainer_t: Trainer
    r"""Trainer object for the current task `self.task_id`."""

    # task ID control
    self.task_id: int
    r"""Task ID counter indicating which task is being processed. It is updated automatically during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
    self.processed_task_ids: list[int] = []
    r"""Task IDs that have been processed."""

Args:

  • cfg (DictConfig): the complete config dict for the continual learning main experiment.
cfg: omegaconf.dictconfig.DictConfig

The complete config dict.

cl_paradigm: str

The continual learning paradigm.

train_tasks: list[int]

The list of task IDs to train.

eval_after_tasks: list[int]

If task ID $t$ is in this list, run the evaluation process for all seen tasks after training task $t$.

global_seed: int

The global seed for the entire experiment.

output_dir: str

The folder for storing the experiment results.

CL dataset object.

Backbone network object.

CL output heads object.

CL model object.

lightning_loggers: list[lightning.pytorch.loggers.logger.Logger]

Lightning logger objects.

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

Callback objects.

optimizer_t: torch.optim.optimizer.Optimizer

Optimizer object for the current task self.task_id.

lr_scheduler_t: torch.optim.lr_scheduler.LRScheduler | None

Learning rate scheduler object for the current task self.task_id.

trainer_t: lightning.pytorch.trainer.trainer.Trainer

Trainer object for the current task self.task_id.

task_id: int

Task ID counter indicating which task is being processed. It is updated automatically during the task loop. Valid from 1 to the number of tasks in the CL dataset.

processed_task_ids: list[int]

Task IDs that have been processed.

def sanity_check(self) -> None:
    r"""Sanity check for config.

    **Raises:**
    - **KeyError**: if a required field is missing or empty, or if the number of tasks cannot be determined from the `cl_dataset` config.
    - **ValueError**: if `cl_paradigm` is not 'TIL', 'CIL' or 'DIL'; if `train_tasks` or `eval_after_tasks` is an empty list or out of range; or if both are lists and `eval_after_tasks` is not a subset of `train_tasks`.
    - **TypeError**: if `train_tasks` or `eval_after_tasks` is neither a list nor an integer.
    """

    # check required config fields
    required_config_fields = [
        "pipeline",
        "expr_name",
        "cl_paradigm",
        "train_tasks",
        "eval_after_tasks",
        "global_seed",
        "cl_dataset",
        "cl_algorithm",
        "backbone",
        "optimizer",
        "trainer",
        "metrics",
        "lightning_loggers",
        "callbacks",
        "output_dir",
        # "hydra" is excluded as it doesn't appear
        "misc",
    ]
    for field in required_config_fields:
        value = self.cfg.get(field)
        # an integer 0 is a legitimate value (e.g. `global_seed: 0`, or `train_tasks: 0`,
        # which the range check below explicitly allows), so only `None` and empty
        # strings/containers count as missing
        if value is None or (not isinstance(value, int) and not value):
            raise KeyError(
                f"Field `{field}` is required in the experiment index config."
            )

    # check cl_paradigm
    if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
        raise ValueError(
            f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
        )

    # get dataset number of tasks
    if self.cfg.cl_dataset.get("num_tasks"):
        num_tasks = self.cfg.cl_dataset.get("num_tasks")
    elif self.cfg.cl_dataset.get("class_split"):
        num_tasks = len(self.cfg.cl_dataset.class_split)
    elif self.cfg.cl_dataset.get("datasets"):
        num_tasks = len(self.cfg.cl_dataset.datasets)
    else:
        raise KeyError(
            "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
        )

    # check train_tasks
    train_tasks = self.cfg.train_tasks
    if isinstance(train_tasks, ListConfig):
        if len(train_tasks) < 1:
            raise ValueError("`train_tasks` config must contain at least one task.")
        if any(t < 1 or t > num_tasks for t in train_tasks):
            raise ValueError(
                f"All task IDs in `train_tasks` config must be between 1 and {num_tasks}."
            )
    elif isinstance(train_tasks, int):
        if train_tasks < 0 or train_tasks > num_tasks:
            raise ValueError(
                f"`train_tasks` config as integer must be between 0 and {num_tasks}."
            )
    else:
        raise TypeError(
            "`train_tasks` config must be either a list of integers or an integer."
        )

    # check eval_after_tasks
    eval_after_tasks = self.cfg.eval_after_tasks
    if isinstance(eval_after_tasks, ListConfig):
        if len(eval_after_tasks) < 1:
            raise ValueError(
                "`eval_after_tasks` config must contain at least one task."
            )
        if any(t < 1 or t > num_tasks for t in eval_after_tasks):
            raise ValueError(
                f"All task IDs in `eval_after_tasks` config must be between 1 and {num_tasks}."
            )
    elif isinstance(eval_after_tasks, int):
        if eval_after_tasks < 0 or eval_after_tasks > num_tasks:
            raise ValueError(
                f"`eval_after_tasks` config as integer must be between 0 and {num_tasks}."
            )
    else:
        raise TypeError(
            "`eval_after_tasks` config must be either a list of integers or an integer."
        )

    # check that eval_after_tasks is a subset of train_tasks when both are explicit lists.
    # Lists in an OmegaConf config arrive as `ListConfig`, which is not a `list` subclass,
    # so the test must be against `ListConfig` (an `isinstance(..., list)` test never matches).
    if isinstance(train_tasks, ListConfig) and isinstance(
        eval_after_tasks, ListConfig
    ):
        if not set(eval_after_tasks).issubset(set(train_tasks)):
            raise ValueError(
                "`eval_after_tasks` config must be a subset of `train_tasks` config."
            )

Sanity check for config.

def instantiate_cl_dataset(self, cl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
    r"""Build the CL dataset from `cl_dataset_cfg` and store it in `self.cl_dataset`.

    **Args:**
    - **cl_dataset_cfg** (`DictConfig`): the Hydra config of the CL dataset; its `_target_` is reported in the debug log.
    """
    target = cl_dataset_cfg.get("_target_")
    pylogger.debug(
        "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...", target
    )
    self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
    pylogger.debug(
        "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!", target
    )

Instantiate the CL dataset object from cl_dataset_cfg.

def instantiate_backbone( self, backbone_cfg: omegaconf.dictconfig.DictConfig, disable_unlearning: bool) -> None:
def instantiate_backbone(
    self, backbone_cfg: DictConfig, disable_unlearning: bool
) -> None:
    r"""Build the CL backbone network and store it in `self.backbone`.

    **Args:**
    - **backbone_cfg** (`DictConfig`): the Hydra config of the backbone network.
    - **disable_unlearning** (`bool`): forwarded to the backbone constructor as a keyword argument.
    """
    backbone_name = backbone_cfg.get("_target_")
    pylogger.debug(
        "Instantiating backbone network <%s> (clarena.backbones.CLBackbone)...",
        backbone_name,
    )
    self.backbone = hydra.utils.instantiate(
        backbone_cfg, disable_unlearning=disable_unlearning
    )
    pylogger.debug(
        "Backbone network <%s> (clarena.backbones.CLBackbone) instantiated!",
        backbone_name,
    )

Instantiate the CL backbone network object from backbone_cfg.

def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
    r"""Instantiate the CL output heads object.

    **Args:**
    - **cl_paradigm** (`str`): the CL paradigm, either 'TIL', 'CIL' or 'DIL'. 'TIL' uses `HeadsTIL`, 'CIL' uses `HeadsCIL`, and 'DIL' uses `HeadDIL`.
    - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.

    **Raises:**
    - **ValueError**: if `cl_paradigm` is not one of 'TIL', 'CIL' or 'DIL'.
    """
    pylogger.debug(
        "CL paradigm is set as %s. Instantiating %s heads...",
        cl_paradigm,
        cl_paradigm,
    )
    if cl_paradigm == "TIL":
        self.heads = HeadsTIL(input_dim=input_dim)
    elif cl_paradigm == "CIL":
        self.heads = HeadsCIL(input_dim=input_dim)
    elif cl_paradigm == "DIL":
        self.heads = HeadDIL(input_dim=input_dim)
    else:
        # without this branch an unknown paradigm would leave `self.heads` unset
        # while still logging success below
        raise ValueError(
            f"`cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {cl_paradigm}!"
        )

    pylogger.debug("%s heads instantiated!", cl_paradigm)

Instantiate the CL output heads object.

Args:

  • cl_paradigm (str): the CL paradigm, either 'TIL', 'CIL' or 'DIL'. 'TIL' uses HeadsTIL, 'CIL' uses HeadsCIL, and 'DIL' uses HeadDIL.
  • input_dim (int): the input dimension of the heads. Must be equal to the output_dim of the connected backbone.
def instantiate_cl_algorithm( self, cl_algorithm_cfg: omegaconf.dictconfig.DictConfig, backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any], disable_unlearning: bool) -> None:
def instantiate_cl_algorithm(
    self,
    cl_algorithm_cfg: DictConfig,
    backbone: CLBackbone,
    heads: HeadsTIL | HeadsCIL | HeadDIL,
    non_algorithmic_hparams: dict[str, Any],
    disable_unlearning: bool,
) -> None:
    r"""Build the CL algorithm and store it in `self.model`.

    **Args:**
    - **cl_algorithm_cfg** (`DictConfig`): the Hydra config of the CL algorithm.
    - **backbone** (`CLBackbone`): the backbone network passed to the algorithm constructor.
    - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): the output heads passed to the algorithm constructor.
    - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters, forwarded to the algorithm constructor.
    - **disable_unlearning** (`bool`): forwarded to the algorithm constructor.
    """
    algorithm_name = cl_algorithm_cfg.get("_target_")
    pylogger.debug(
        "CL algorithm is set as <%s>. Instantiating <%s> (clarena.cl_algorithms.CLAlgorithm)...",
        algorithm_name,
        algorithm_name,
    )
    constructor_kwargs = {
        "backbone": backbone,
        "heads": heads,
        "non_algorithmic_hparams": non_algorithmic_hparams,
        "disable_unlearning": disable_unlearning,
    }
    self.model = hydra.utils.instantiate(cl_algorithm_cfg, **constructor_kwargs)
    pylogger.debug(
        "<%s> (clarena.cl_algorithms.CLAlgorithm) instantiated!", algorithm_name
    )

Instantiate the cl_algorithm object from cl_algorithm_cfg, backbone, heads and non_algorithmic_hparams.

def instantiate_optimizer( self, optimizer_cfg: omegaconf.dictconfig.DictConfig, task_id: int) -> None:
def instantiate_optimizer(
    self,
    optimizer_cfg: DictConfig,
    task_id: int,
) -> None:
    r"""Build the optimizer for task `task_id` and store it in `self.optimizer_t`.

    **Args:**
    - **optimizer_cfg** (`DictConfig`): either one optimizer config with a top-level `_target_`, applied uniformly to all tasks, or a config indexed by task ID holding a distinct optimizer config per task.
    - **task_id** (`int`): the task whose optimizer is being built.

    The optimizer is only partially instantiated here, because its 'params' argument comes from the Lightning module and cannot be passed at this point.
    """
    has_target = bool(optimizer_cfg.get("_target_"))
    if has_target:
        pylogger.debug("Uniform optimizer config is applied to all tasks.")
        task_optimizer_cfg = optimizer_cfg
    else:
        # no top-level `_target_`: the config is indexed by task ID
        pylogger.debug("Distinct optimizer config is applied to each task.")
        task_optimizer_cfg = optimizer_cfg[task_id]

    target = task_optimizer_cfg.get("_target_")
    pylogger.debug(
        "Partially instantiating optimizer <%s> (torch.optim.Optimizer) for task %d...",
        target,
        task_id,
    )
    self.optimizer_t = hydra.utils.instantiate(task_optimizer_cfg)
    pylogger.debug(
        "Optimizer <%s> (torch.optim.Optimizer) partially for task %d instantiated!",
        target,
        task_id,
    )

Instantiate the optimizer object for task task_id from optimizer_cfg.

def instantiate_lr_scheduler( self, lr_scheduler_cfg: omegaconf.dictconfig.DictConfig, task_id: int) -> None:
def instantiate_lr_scheduler(
    self,
    lr_scheduler_cfg: DictConfig,
    task_id: int,
) -> None:
    r"""Build the learning rate scheduler for task `task_id` and store it in `self.lr_scheduler_t`.

    **Args:**
    - **lr_scheduler_cfg** (`DictConfig`): either one scheduler config with a top-level `_target_`, applied uniformly to all tasks, or a config indexed by task ID holding a distinct scheduler config per task.
    - **task_id** (`int`): the task whose learning rate scheduler is being built.

    The scheduler is only partially instantiated here, because its 'optimizer' argument comes from the Lightning module and cannot be passed at this point.
    """
    is_uniform = bool(lr_scheduler_cfg.get("_target_"))
    if is_uniform:
        pylogger.debug(
            "Uniform learning rate scheduler config is applied to all tasks."
        )
        task_cfg = lr_scheduler_cfg
    else:
        # no top-level `_target_`: the config is indexed by task ID
        pylogger.debug(
            "Distinct learning rate scheduler config is applied to each task."
        )
        task_cfg = lr_scheduler_cfg[task_id]

    target = task_cfg.get("_target_")
    pylogger.debug(
        "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) for task %d...",
        target,
        task_id,
    )
    self.lr_scheduler_t = hydra.utils.instantiate(task_cfg)
    pylogger.debug(
        "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially for task %d instantiated!",
        target,
        task_id,
    )

Instantiate the learning rate scheduler object for task task_id from lr_scheduler_cfg.

def instantiate_lightning_loggers(self, lightning_loggers_cfg: omegaconf.dictconfig.DictConfig) -> None:
def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
    r"""Build every Lightning logger configured in `lightning_loggers_cfg`.

    **Args:**
    - **lightning_loggers_cfg** (`DictConfig`): a mapping whose values are individual logger configs (the keys are not used).

    The loggers are stored in `self.lightning_loggers`, in the iteration order of the mapping.
    """
    pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
    logger_cfgs = lightning_loggers_cfg.values()
    self.lightning_loggers = list(map(hydra.utils.instantiate, logger_cfgs))
    pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")

Instantiate the list of lightning loggers objects from lightning_loggers_cfg.

def instantiate_callbacks( self, metrics_cfg: omegaconf.listconfig.ListConfig, callbacks_cfg: omegaconf.listconfig.ListConfig) -> None:
def instantiate_callbacks(
    self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
) -> None:
    r"""Instantiate all callbacks objects from `metrics_cfg` and `callbacks_cfg`.

    **Args:**
    - **metrics_cfg** (`ListConfig`): configs of the metric callbacks.
    - **callbacks_cfg** (`ListConfig`): configs of the callbacks other than the metric callbacks.

    The instantiated callbacks are stored in `self.callbacks`: the metric callbacks first, followed by the other callbacks, each group in config order.
    """
    pylogger.debug(
        "Instantiating callbacks (lightning.Callback)...",
    )

    # instantiate metric callbacks
    metric_callbacks = [
        hydra.utils.instantiate(callback) for callback in metrics_cfg
    ]

    # instantiate other callbacks
    other_callbacks = [
        hydra.utils.instantiate(callback) for callback in callbacks_cfg
    ]

    # combine both groups into one list, metric callbacks first
    self.callbacks = metric_callbacks + other_callbacks
    pylogger.debug("Callbacks (lightning.Callback) instantiated!")

Instantiate the list of callbacks objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.

def instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, lightning_loggers: list[lightning.pytorch.loggers.logger.Logger], callbacks: list[lightning.pytorch.callbacks.callback.Callback], task_id: int) -> None:
def instantiate_trainer(
    self,
    trainer_cfg: DictConfig,
    lightning_loggers: list[Logger],
    callbacks: list[Callback],
    task_id: int,
) -> None:
    r"""Build the Lightning trainer for task `task_id` and store it in `self.trainer_t`.

    **Args:**
    - **trainer_cfg** (`DictConfig`): either one trainer config with a top-level `_target_`, used for every task, or a config indexed by task ID holding a distinct trainer config per task.
    - **lightning_loggers** (`list[Logger]`): passed to the trainer as its `logger` argument.
    - **callbacks** (`list[Callback]`): passed to the trainer as its `callbacks` argument.
    - **task_id** (`int`): the task whose trainer is being built.
    """
    if trainer_cfg.get("_target_"):
        pylogger.debug("Uniform trainer config is applied to all tasks.")
        task_trainer_cfg = trainer_cfg
    else:
        pylogger.debug("Distinct trainer config is applied to each task.")
        task_trainer_cfg = trainer_cfg[task_id]

    pylogger.debug("Instantiating trainer (lightning.Trainer) for task %d...", task_id)
    self.trainer_t = hydra.utils.instantiate(
        task_trainer_cfg, logger=lightning_loggers, callbacks=callbacks
    )
    pylogger.debug("Trainer (lightning.Trainer) for task %d instantiated!", task_id)

Instantiate the trainer object for task task_id from trainer_cfg, lightning_loggers, and callbacks.

def set_global_seed(self, global_seed: int) -> None:
def set_global_seed(self, global_seed: int) -> None:
    r"""Set the `global_seed` for the entire experiment.

    **Args:**
    - **global_seed** (`int`): the seed passed to `lightning.seed_everything` (with `workers=True`, which also configures seeding of dataloader workers).
    """
    # seed with the argument (not `self.global_seed`) so that the seed applied is the one logged below
    L.seed_everything(global_seed, workers=True)
    pylogger.debug("Global seed is set as %d.", global_seed)

Set the global_seed for the entire experiment.

def run(self) -> None:
def run(self) -> None:
    r"""Run the continual learning main experiment end to end.

    The global components (CL dataset, backbone, heads, CL algorithm, Lightning loggers and callbacks) are built once. Then, for every task ID in `self.train_tasks`:
    1. a task-specific optimizer, optional learning rate scheduler and trainer are built;
    2. the model is fitted on that task;
    3. the model is tested if the task ID is in `self.eval_after_tasks`.

    Each finished task ID is appended to `self.processed_task_ids`.
    """
    cfg = self.cfg
    self.set_global_seed(self.global_seed)

    # ---- global components, built once for the whole experiment ----
    self.instantiate_cl_dataset(cl_dataset_cfg=cfg.cl_dataset)
    self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
    self.instantiate_backbone(backbone_cfg=cfg.backbone, disable_unlearning=True)
    self.instantiate_heads(
        cl_paradigm=self.cl_paradigm, input_dim=cfg.backbone.output_dim
    )
    # the CL algorithm takes the backbone and heads as arguments, so it is built after them
    hparams = select_hyperparameters_from_config(cfg=cfg, type=cfg.pipeline)
    self.instantiate_cl_algorithm(
        cl_algorithm_cfg=cfg.cl_algorithm,
        backbone=self.backbone,
        heads=self.heads,
        non_algorithmic_hparams=hparams,
        disable_unlearning=True,
    )
    self.instantiate_lightning_loggers(lightning_loggers_cfg=cfg.lightning_loggers)
    self.instantiate_callbacks(metrics_cfg=cfg.metrics, callbacks_cfg=cfg.callbacks)

    # ---- task loop ----
    for task_id in self.train_tasks:
        self.task_id = task_id

        # task-specific components
        self.instantiate_optimizer(optimizer_cfg=cfg.optimizer, task_id=task_id)
        if cfg.get("lr_scheduler"):
            self.instantiate_lr_scheduler(
                lr_scheduler_cfg=cfg.lr_scheduler, task_id=task_id
            )
        # the trainer receives the loggers and callbacks built above
        self.instantiate_trainer(
            trainer_cfg=cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
            task_id=task_id,
        )

        # point the dataset and the model at the current task
        self.cl_dataset.setup_task_id(task_id=task_id)
        class_map = self.cl_dataset.get_cl_class_map(self.task_id)
        self.model.setup_task_id(
            task_id=task_id,
            num_classes=len(class_map),
            optimizer=self.optimizer_t,
            lr_scheduler=self.lr_scheduler_t,
        )

        # train and validate, then optionally evaluate
        self.trainer_t.fit(model=self.model, datamodule=self.cl_dataset)
        if task_id in self.eval_after_tasks:
            self.trainer_t.test(model=self.model, datamodule=self.cl_dataset)

        self.processed_task_ids.append(task_id)

The main method to run the continual learning main experiment.