clarena

Welcome to CLArena

CLArena (Continual Learning Arena) is an open-source Python package for Continual Learning (CL) research. It provides an integrated environment and various APIs for conducting CL experiments, together with implemented CL algorithms and datasets that you can give a spin immediately.

Please note that this is API documentation providing detailed information about the available classes, functions, and modules in CLArena. Please refer to the main documentation and my beginners' guide to continual learning for more intuitive tutorials, examples, and guides on how to use CLArena:

  • Main Documentation: https://pengxiang-wang.com/projects/continual-learning-arena
  • A Beginners' Guide to Continual Learning: https://pengxiang-wang.com/posts/continual-learning-beginners-guide

We provide the various components of a continual learning system in the following submodules:

  • clarena.cl_datasets: Continual learning datasets.
  • clarena.backbones: Neural network architectures used as backbones for CL algorithms.
  • clarena.cl_heads: Multi-head classifiers for continual learning outputs. Task-Incremental Learning (TIL) head and Class-Incremental Learning (CIL) head are included.
  • clarena.cl_algorithms: Implementation of various continual learning algorithms.
  • clarena.callbacks: Extra actions added in the continual learning process.
  • clarena.utils: Utility functions for continual learning experiments.

As well as the base class at the outermost level of the package:

  • CLExperiment: The base class for continual learning experiments.
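
As a quick orientation, the sketch below shows how an experiment might be launched from a Hydra entry-point script. It is a minimal, hedged example: the config directory "configs", the config name "experiment", and the config contents are placeholders you would replace with your own Hydra configuration composed of the groups checked and instantiated by CLExperiment (documented below).

import hydra
from omegaconf import DictConfig

from clarena import CLExperiment


# Hypothetical entry point; "configs" and "experiment" are placeholder names for
# your own Hydra config directory and top-level config file.
@hydra.main(config_path="configs", config_name="experiment", version_base=None)
def main(cfg: DictConfig) -> None:
    expr = CLExperiment(cfg)
    expr.run()  # loops over tasks 1..cfg.num_tasks, fitting and optionally testing each


if __name__ == "__main__":
    main()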
 1r"""
 2
 3# Welcome to CLArena
 4
 5**CLArena (Continual Learning Arena)** is a open-source Python package for Continual Learning (CL) research. In this package, we provide a integrated environment and various APIs to conduct CL experiments for research purposes, as well as implemented CL algorithms and datasets that you can give it a spin immediately.
 6
 7Please note that this is an API documantation providing detailed information about the available classes, functions, and modules in CLArena. Please refer to the main documentation and my beginners' guide to continual learning for more intuitive tutorials, examples, and guides on how to use CLArena:
 8
 9- [**Main Documentation**](https://pengxiang-wang.com/projects/continual-learning-arena)
10- [**A Beginners' Guide to Continual Learning**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide)
11
12We provide various components of continual learning system in the submodules:
13
14- `clarena.cl_datasets`: Continual learning datasets.
15- `clarena.backbones`: Neural network architectures used as backbones for CL algorithms.
16- `clarena.cl_heads`: Multi-head classifiers for continual learning outputs. Task-Incremental Learning (TIL) head and Class-Incremental Learning (CIL) head are included.
17- `clarena.cl_algorithms`: Implementation of various continual learning algorithms.
18- `clarena.callbacks`: Extra actions added in the continual learning process.
19- `utils`: Utility functions for continual learning experiments.
20
21As well as the base class in the outmost directory of the package:
22
23- `CLExperiment`: The base class for continual learning experiments.
24
25"""
26
27from .base import CLExperiment
28
29__all__ = [
30    "CLExperiment",
31    "cl_datasets",
32    "backbones",
33    "cl_heads",
34    "cl_algorithms",
35    "callbacks",
36    "utils",
37]
class CLExperiment:

The base class for continual learning experiments.

CLExperiment(cfg: omegaconf.dictconfig.DictConfig)
    def __init__(self, cfg: DictConfig) -> None:
        r"""Initializes the CL experiment object with a complete configuration.

        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the CL experiment.
        """
        self.cfg: DictConfig = cfg
        r"""Store the complete config dict for any future reference."""

        self.cl_paradigm: str = cfg.cl_paradigm
        r"""Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Parsed from config and used to instantiate the correct heads object and set up CL dataset."""
        self.num_tasks: int = cfg.num_tasks
        r"""Store the number of tasks to be conducted in this experiment. Parsed from config and used in the task loop."""
        self.global_seed: int | None = cfg.get("global_seed", None)
        r"""Store the global seed for the entire experiment, or None if not specified. Parsed from config and used to seed all random number generators."""
        self.test: bool = cfg.test
        r"""Store whether to test the model after training and validation. Parsed from config and used in the task loop."""
        self.output_dir_name: str = cfg.output_dir_name
        r"""Store the name of the output directory for logs and checkpoints. Parsed from config and helps any output operation locate the correct directory."""

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Updated automatically during the task loop."""

        self.cl_dataset: CLDataset
        r"""CL dataset object. Instantiated in `instantiate_cl_dataset()`."""
        self.backbone: CLBackbone
        r"""Backbone network object. Instantiated in `instantiate_backbone()`."""
        self.heads: HeadsTIL | HeadsCIL
        r"""CL output heads object. Instantiated in `instantiate_heads()`."""
        self.model: CLAlgorithm
        r"""CL model object. Instantiated in `instantiate_cl_algorithm()`."""

        self.optimizer: Optimizer
        r"""Optimizer object for the current task `self.task_id`. Instantiated in `instantiate_optimizer()`."""
        self.trainer: Trainer
        r"""Trainer object for the current task `self.task_id`. Instantiated in `instantiate_trainer()`."""
        self.lightning_loggers: list[Logger]
        r"""The list of initialised Lightning logger objects for the current task `self.task_id`. Instantiated in `instantiate_lightning_loggers()`."""
        self.callbacks: list[Callback]
        r"""The list of initialised callback objects for the current task `self.task_id`. Instantiated in `instantiate_callbacks()`."""

        self.sanity_check()

Initializes the CL experiment object with a complete configuration.

Args:

  • cfg (DictConfig): the complete config dict for the CL experiment.
cfg: omegaconf.dictconfig.DictConfig

Store the complete config dict for any future reference.

cl_paradigm: str

Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Parsed from config and used to instantiate the correct heads object and set up CL dataset.

num_tasks: int

Store the number of tasks to be conducted in this experiment. Parsed from config and used in the task loop.

global_seed: int | None

Store the global seed for the entire experiment, or None if not specified. Parsed from config and used to seed all random number generators.

test: bool

Store whether to test the model after training and validation. Parsed from config and used in the task loop.

output_dir_name: str

Store the name of the output directory for logs and checkpoints. Parsed from config and helps any output operation locate the correct directory.

task_id: int

Task ID counter indicating which task is being processed. Updated automatically during the task loop.

CL dataset object. Instantiated in instantiate_cl_dataset().

Backbone network object. Instantiated in instantiate_backbone().

CL output heads object. Instantiated in instantiate_heads().

CL model object. Instantiated in instantiate_cl_algorithm().

optimizer: torch.optim.optimizer.Optimizer

Optimizer object for the current task self.task_id. Instantiated in instantiate_optimizer().

trainer: lightning.pytorch.trainer.trainer.Trainer

Trainer object for the current task self.task_id. Instantiated in instantiate_trainer().

lightning_loggers: list[lightning.pytorch.loggers.logger.Logger]

The list of initialised Lightning logger objects for the current task self.task_id. Instantiated in instantiate_lightning_loggers().

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

The list of initialised callback objects for the current task self.task_id. Instantiated in instantiate_callbacks().

def sanity_check(self) -> None:
    def sanity_check(self) -> None:
        r"""Check the sanity of the config dict `self.cfg`.

        **Raises:**
        - **KeyError**: when required fields in experiment config are missing, including `cl_paradigm`, `num_tasks`, `test`, `output_dir_name`.
        - **ValueError**: when the value of `cl_paradigm` is not 'TIL' or 'CIL', or when the number of tasks is larger than the number of tasks in the CL dataset.
        """
        if not self.cfg.get("cl_paradigm"):
            raise KeyError(
                "Field cl_paradigm should be specified in experiment config!"
            )

        if self.cfg.cl_paradigm not in ["TIL", "CIL"]:
            raise ValueError(
                f"Field cl_paradigm should be either 'TIL' or 'CIL' but got {self.cfg.cl_paradigm}!"
            )

        if not self.cfg.get("num_tasks"):
            raise KeyError("Field num_tasks should be specified in experiment config!")

        if not self.cfg.cl_dataset.get("num_tasks"):
            raise KeyError("Field num_tasks should be specified in cl_dataset config!")

        if self.cfg.num_tasks > self.cfg.cl_dataset.num_tasks:
            raise ValueError(
                f"The experiment is set to run {self.cfg.num_tasks} tasks whereas only {self.cfg.cl_dataset.num_tasks} exist in the current cl_dataset setting!"
            )

        if "test" not in self.cfg:
            # check for presence rather than truthiness so that `test: false` is still accepted
            raise KeyError("Field test should be specified in experiment config!")

        if not self.cfg.get("output_dir_name"):
            raise KeyError(
                "Field output_dir_name should be specified in experiment config!"
            )

Check the sanity of the config dict self.cfg.

Raises:

  • KeyError: when required fields in experiment config are missing, including cl_paradigm, num_tasks, test, output_dir_name.
  • ValueError: when the value of cl_paradigm is not 'TIL' or 'CIL', or when the number of tasks is larger than the number of tasks in the CL dataset.
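
For reference, a minimal config fragment that satisfies these checks might look as follows. This is a sketch only: the values are hypothetical, and a complete experiment config additionally needs the backbone, cl_algorithm, optimizer, trainer, lightning_loggers, and callbacks groups consumed by the instantiate_* methods below.

from omegaconf import OmegaConf

# Hypothetical fields; only the keys inspected by sanity_check() are shown.
cfg = OmegaConf.create(
    {
        "cl_paradigm": "TIL",              # must be 'TIL' or 'CIL'
        "num_tasks": 5,                    # must not exceed cl_dataset.num_tasks
        "test": True,                      # whether to test after training each task
        "output_dir_name": "outputs/example",
        "cl_dataset": {"num_tasks": 10},   # a real cl_dataset config also carries a _target_ etc.
    }
)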
def instantiate_cl_dataset(self, cl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        r"""Instantiate the CL dataset object from cl_dataset config.

        **Args:**
        - **cl_dataset_cfg** (`DictConfig`): the cl_dataset config dict.
        """
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            cl_dataset_cfg.get("_target_"),
        )
        self.cl_dataset: LightningDataModule = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            cl_dataset_cfg.get("_target_"),
        )

Instantiate the CL dataset object from cl_dataset config.

Args:

  • cl_dataset_cfg (DictConfig): the cl_dataset config dict.
def instantiate_backbone(self, backbone_cfg: omegaconf.dictconfig.DictConfig) -> None:
    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
        r"""Instantiate the CL backbone network object from backbone config.

        **Args:**
        - **backbone_cfg** (`DictConfig`): the backbone config dict.
        """
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.CLBackbone)...",
            backbone_cfg.get("_target_"),
        )
        self.backbone: nn.Module = hydra.utils.instantiate(backbone_cfg)
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.CLBackbone) instantiated!",
            backbone_cfg.get("_target_"),
        )

Instantiate the CL backbone network object from backbone config.

Args:

  • backbone_cfg (DictConfig): the backbone config dict.
def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
    def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
        r"""Instantiate the CL output heads object according to field `cl_paradigm` and backbone `output_dim` in the config.

        **Args:**
        - **cl_paradigm** (`str`): the CL paradigm, either 'TIL' or 'CIL'.
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug(
            "CL paradigm is set as %s. Instantiating %s heads (torch.nn.Module)...",
            cl_paradigm,
            cl_paradigm,
        )
        self.heads: HeadsTIL | HeadsCIL = (
            HeadsTIL(input_dim=input_dim)
            if cl_paradigm == "TIL"
            else HeadsCIL(input_dim=input_dim)
        )
        pylogger.debug("%s heads (torch.nn.Module) instantiated!", cl_paradigm)

Instantiate the CL output heads object according to field cl_paradigm and backbone output_dim in the config.

Args:

  • cl_paradigm (str): the CL paradigm, either 'TIL' or 'CIL'.
  • input_dim (int): the input dimension of the heads. Must be equal to the output_dim of the connected backbone.
def instantiate_cl_algorithm(self, cl_algorithm_cfg: omegaconf.dictconfig.DictConfig) -> None:
    def instantiate_cl_algorithm(self, cl_algorithm_cfg: DictConfig) -> None:
        r"""Instantiate the cl_algorithm object from cl_algorithm config.

        **Args:**
        - **cl_algorithm_cfg** (`DictConfig`): the cl_algorithm config dict.
        """
        pylogger.debug(
            "CL algorithm is set as <%s>. Instantiating <%s> (clarena.cl_algorithms.CLAlgorithm)...",
            cl_algorithm_cfg.get("_target_"),
            cl_algorithm_cfg.get("_target_"),
        )
        self.model: LightningModule = hydra.utils.instantiate(
            cl_algorithm_cfg,
            backbone=self.backbone,
            heads=self.heads,
        )
        pylogger.debug(
            "<%s> (clarena.cl_algorithms.CLAlgorithm) instantiated!",
            cl_algorithm_cfg.get("_target_"),
        )

Instantiate the cl_algorithm object from cl_algorithm config.

Args:

  • cl_algorithm_cfg (DictConfig): the cl_algorithm config dict.
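
Because the algorithm is instantiated through hydra.utils.instantiate, its config is expected to carry a _target_ pointing at a class in clarena.cl_algorithms, with the backbone and heads injected as keyword arguments at instantiation time. A hedged sketch (the class name SomeAlgorithm and its hyperparameter are hypothetical placeholders; see clarena.cl_algorithms for the actual classes and their arguments):

from omegaconf import OmegaConf

# Hypothetical cl_algorithm config; extra keys become constructor arguments of the
# targeted CLAlgorithm subclass.
cl_algorithm_cfg = OmegaConf.create(
    {
        "_target_": "clarena.cl_algorithms.SomeAlgorithm",  # placeholder class path
        "some_hyperparameter": 0.5,                         # placeholder option
    }
)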
def instantiate_optimizer(self, optimizer_cfg: omegaconf.dictconfig.DictConfig | omegaconf.listconfig.ListConfig, task_id: int) -> None:
    def instantiate_optimizer(
        self, optimizer_cfg: DictConfig | ListConfig, task_id: int
    ) -> None:
        r"""Instantiate the optimizer object for task `task_id` from optimizer config.

        **Args:**
        - **optimizer_cfg** (`DictConfig` or `ListConfig`): the optimizer config dict. If it's a `ListConfig`, it should contain an optimizer config for each task; otherwise, it's a uniform optimizer config for all tasks.
        - **task_id** (`int`): the target task ID.
        """
        if isinstance(optimizer_cfg, ListConfig):
            pylogger.debug("Distinct optimizer config is applied to each task.")
            optimizer_cfg = optimizer_cfg[task_id - 1]  # pick the config for this task
        elif isinstance(optimizer_cfg, DictConfig):
            pylogger.debug("Uniform optimizer config is applied to all tasks.")

        # partially instantiate the optimizer, because the 'params' argument comes from
        # the LightningModule and cannot be passed at this point.
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer) for task %d...",
            optimizer_cfg.get("_target_"),
            task_id,
        )
        self.optimizer: Optimizer = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) for task %d partially instantiated!",
            optimizer_cfg.get("_target_"),
            task_id,
        )

Instantiate the optimizer object for task task_id from optimizer config.

Args:

  • optimizer_cfg (DictConfig or ListConfig): the optimizer config dict. If it's a ListConfig, it should contain an optimizer config for each task; otherwise, it's a uniform optimizer config for all tasks.
  • task_id (int): the target task ID.
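
To illustrate the two accepted forms, here is a hedged sketch of a uniform optimizer config versus a per-task list. The learning rates and the use of Hydra's _partial_ flag are illustrative assumptions (the 'params' argument is supplied later by the Lightning module):

from omegaconf import OmegaConf

# Uniform: a single optimizer config shared by every task (DictConfig).
uniform_optimizer_cfg = OmegaConf.create(
    {"_target_": "torch.optim.SGD", "_partial_": True, "lr": 0.01}
)

# Per-task: one entry per task, indexed by task_id - 1 (ListConfig).
per_task_optimizer_cfg = OmegaConf.create(
    [
        {"_target_": "torch.optim.SGD", "_partial_": True, "lr": 0.1},     # task 1
        {"_target_": "torch.optim.Adam", "_partial_": True, "lr": 0.001},  # task 2
    ]
)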
def instantiate_trainer(self, trainer_cfg: omegaconf.dictconfig.DictConfig, task_id: int) -> None:
    def instantiate_trainer(self, trainer_cfg: DictConfig, task_id: int) -> None:
        r"""Instantiate the trainer object for task `task_id` from trainer config.

        **Args:**
        - **trainer_cfg** (`DictConfig`): the trainer config dict. All tasks share the same trainer config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating trainer <%s> (lightning.Trainer) for task %d...",
            trainer_cfg.get("_target_"),
            task_id,
        )
        self.trainer: Trainer = hydra.utils.instantiate(
            trainer_cfg, callbacks=self.callbacks, logger=self.lightning_loggers
        )
        pylogger.debug(
            "Trainer <%s> (lightning.Trainer) for task %d instantiated!",
            trainer_cfg.get("_target_"),
            task_id,
        )

Instantiate the trainer object for task task_id from trainer config.

Args:

  • trainer_cfg (DictConfig): the trainer config dict. All tasks share the same trainer config but different objects.
  • task_id (int): the target task ID.
def instantiate_lightning_loggers(self, lightning_loggers_cfg: omegaconf.dictconfig.DictConfig, task_id: int) -> None:
    def instantiate_lightning_loggers(
        self, lightning_loggers_cfg: DictConfig, task_id: int
    ) -> None:
        r"""Instantiate the list of Lightning logger objects for task `task_id` from lightning_loggers config.

        **Args:**
        - **lightning_loggers_cfg** (`DictConfig`): the lightning_loggers config dict. All tasks share the same lightning_loggers config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating Lightning loggers (lightning.Logger) for task %d...", task_id
        )
        self.lightning_loggers: list[Logger] = [
            hydra.utils.instantiate(
                lightning_logger, version=f"task_{task_id}"
            )  # name the version sub-directory with a "task_" prefix in the lightning logs
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug(
            "Lightning loggers (lightning.Logger) for task %d instantiated!", task_id
        )

Instantiate the list of Lightning logger objects for task task_id from lightning_loggers config.

Args:

  • lightning_loggers_cfg (DictConfig): the lightning_loggers config dict. All tasks share the same lightning_loggers config but different objects.
  • task_id (int): the target task ID.
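
A hedged sketch of what a lightning_loggers config group might contain. Each entry is instantiated with Hydra and receives version=f"task_{task_id}", so each task's logs land in a separate sub-directory; the save_dir values are placeholders:

from omegaconf import OmegaConf

lightning_loggers_cfg = OmegaConf.create(
    {
        "csv": {
            "_target_": "lightning.pytorch.loggers.CSVLogger",
            "save_dir": "outputs/example",  # placeholder output directory
        },
        "tensorboard": {
            "_target_": "lightning.pytorch.loggers.TensorBoardLogger",
            "save_dir": "outputs/example",
        },
    }
)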
def instantiate_callbacks(self, callbacks_cfg: omegaconf.dictconfig.DictConfig, task_id: int) -> None:
    def instantiate_callbacks(self, callbacks_cfg: DictConfig, task_id: int) -> None:
        r"""Instantiate the list of callback objects for task `task_id` from callbacks config.

        **Args:**
        - **callbacks_cfg** (`DictConfig`): the callbacks config dict. All tasks share the same callbacks config but different objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating callbacks (lightning.Callback) for task %d...", task_id
        )
        self.callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg.values()
        ]
        pylogger.debug(
            "Callbacks (lightning.Callback) for task %d instantiated!", task_id
        )

Instantiate the list of callback objects for task task_id from callbacks config.

Args:

  • callbacks_cfg (DictConfig): the callbacks config dict. All tasks share the same callbacks config but different objects.
  • task_id (int): the target task ID.
def set_global_seed(self) -> None:
    def set_global_seed(self) -> None:
        r"""Set the global seed for the entire experiment."""
        L.seed_everything(self.global_seed, workers=True)
        pylogger.debug("Global seed is set as %s.", self.global_seed)

Set the global seed for the entire experiment.

def setup_task_id(self, task_id: int) -> None:
    def setup_task_id(self, task_id: int) -> None:
        r"""Set up the current task_id at the beginning of the continual learning process of a new task.

        **Args:**
        - **task_id** (`int`): current task_id.
        """
        self.task_id = task_id

Set up the current task_id at the beginning of the continual learning process of a new task.

Args:

  • task_id (int): current task_id.
def instantiate_global(self) -> None:
    def instantiate_global(self) -> None:
        r"""Instantiate global components for the entire CL experiment from `self.cfg`."""

        self.instantiate_cl_dataset(self.cfg.cl_dataset)
        self.instantiate_backbone(self.cfg.backbone)
        self.instantiate_heads(self.cfg.cl_paradigm, self.cfg.backbone.output_dim)
        self.instantiate_cl_algorithm(
            self.cfg.cl_algorithm
        )  # cl_algorithm should be instantiated after backbone and heads

Instantiate global components for the entire CL experiment from self.cfg.

def setup_global(self) -> None:
    def setup_global(self) -> None:
        r"""Set up global components for the entire experiment: seed all random number generators and let the CL dataset know the CL paradigm so it can define its CL class map."""
        self.set_global_seed()
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)

Set up global components for the entire experiment: seed all random number generators and let the CL dataset know the CL paradigm so it can define its CL class map.

def instantiate_task_specific(self) -> None:
    def instantiate_task_specific(self) -> None:
        r"""Instantiate task-specific components for the current task `self.task_id` from `self.cfg`."""

        self.instantiate_optimizer(self.cfg.optimizer, self.task_id)
        self.instantiate_callbacks(self.cfg.callbacks, self.task_id)
        self.instantiate_lightning_loggers(self.cfg.lightning_loggers, self.task_id)
        self.instantiate_trainer(
            self.cfg.trainer, self.task_id
        )  # trainer should be instantiated after loggers and callbacks

Instantiate task-specific components for the current task self.task_id from self.cfg.

def setup_task_specific(self) -> None:
    def setup_task_specific(self) -> None:
        r"""Set up task-specific components to get ready for the current task `self.task_id`."""

        self.cl_dataset.setup_task_id(self.task_id)
        self.backbone.setup_task_id(self.task_id)
        self.model.setup_task_id(
            self.task_id,
            len(self.cl_dataset.cl_class_map(self.task_id)),
            self.optimizer,
        )

        pylogger.debug(
            "Datamodule, model and loggers are all set up ready for task %d!",
            self.task_id,
        )

Set up task-specific components to get ready for the current task self.task_id.

def run_task(self) -> None:
    def run_task(self) -> None:
        r"""Fit the model on the current task `self.task_id`. Also test the model if `self.test` is set to True."""

        self.trainer.fit(
            model=self.model,
            datamodule=self.cl_dataset,
        )

        if self.test:
            # test after training and validation
            self.trainer.test(
                model=self.model,
                datamodule=self.cl_dataset,
            )

Fit the model on the current task self.task_id. Also test the model if self.test is set to True.

def run(self) -> None:
    def run(self) -> None:
        r"""The main method to run the continual learning experiment."""

        self.instantiate_global()
        self.setup_global()

        # task loop
        for task_id in range(1, self.num_tasks + 1):  # task ID counts from 1

            self.setup_task_id(task_id)
            self.instantiate_task_specific()
            self.setup_task_specific()

            self.run_task()

The main method to run the continual learning experiment.
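
Equivalently, the task loop can be driven by hand, for example to run only a subset of tasks. The sketch below simply mirrors the body of run() shown above, assuming experiment is a CLExperiment built from a complete config (see the sketch in the module overview):

experiment.instantiate_global()
experiment.setup_global()

for task_id in range(1, experiment.num_tasks + 1):  # task IDs count from 1
    experiment.setup_task_id(task_id)
    experiment.instantiate_task_specific()
    experiment.setup_task_specific()
    experiment.run_task()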