clarena

Welcome to CLArena

CLArena (Continual Learning Arena) is an open-source Python package for Continual Learning (CL) research, developed solely by me. In this package, I provide an integrated environment and various APIs to conduct CL experiments for research purposes, as well as implemented CL algorithms and datasets that you can give a spin immediately.

Please note that this is API documentation providing detailed information about the available classes, functions, and modules in CLArena. Please refer to the main documentation and my beginners' guide to continual learning for more intuitive tutorials, examples, and guides on how to use CLArena.

  • Main documentation: https://pengxiang-wang.com/projects/continual-learning-arena/
  • A beginners' guide to continual learning: https://pengxiang-wang.com/posts/continual-learning-beginners-guide

We provide the components of a continual learning system in the following submodules:

  • clarena.cl_datasets: Continual learning datasets.
  • clarena.backbones: Neural network architectures used as backbones for CL algorithms.
  • clarena.cl_heads: Multi-head classifiers for continual learning outputs. Both a Task-Incremental Learning (TIL) head and a Class-Incremental Learning (CIL) head are included.
  • clarena.cl_algorithms: Implementations of various continual learning algorithms.
  • clarena.optimizers: Custom optimizers tailored for continual learning.
  • clarena.callbacks: Extra functionality or actions added to the continual learning process.

The base class sits in the outermost directory of the package:

  • Class CLExperiment: The base class for continual learning experiments.

from .base import CLExperiment

__all__ = [
    "CLExperiment",
    "cl_datasets",
    "backbones",
    "cl_heads",
    "cl_algorithms",
    "optimizers",
    "callbacks",
]
class CLExperiment:
class CLExperiment:
    """The base class for continual learning experiments."""

    def __init__(self, cfg: DictConfig) -> None:
        """Initializes the CL experiment object with a complete configuration.

        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the CL experiment.
        """
        self.cfg: DictConfig = cfg
        """Store the complete config dict for any future reference."""
        self.sanity_check()

        self.cl_paradigm: str = cfg.cl_paradigm
        """Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Parsed from config and used to instantiate the correct heads object and set up the CL dataset."""
        self.num_tasks: int = cfg.num_tasks
        """Store the number of tasks to be conducted in this experiment. Parsed from config and used in the task loop."""
        self.test: bool = cfg.test
        """Store whether to test the model after training and validation. Parsed from config and used in the task loop."""
        self.output_dir_name: str = cfg.output_dir_name
        """Store the name of the output directory to store the logs and checkpoints. Parsed from config and helps any output operation locate the correct directory."""

        self.task_id: int
        """Task ID counter indicating which task is being processed. Updated automatically during the task loop."""

        self.cl_dataset: CLDataset
        """CL dataset object. Instantiated in `instantiate_cl_dataset()`."""
        self.backbone: CLBackbone
        """Backbone network object. Instantiated in `instantiate_backbone()`."""
        self.heads: HeadsTIL | HeadsCIL
        """CL output heads object. Instantiated in `instantiate_heads()`."""
        self.model: CLAlgorithm
        """CL model object. Instantiated in `instantiate_cl_algorithm()`."""

        self.optimizer: Optimizer
        """Optimizer object for current task `self.task_id`. Instantiated in `instantiate_optimizer()`."""
        self.trainer: Trainer
        """Trainer object for current task `self.task_id`. Instantiated in `instantiate_trainer()`."""
        self.lightning_loggers: list[Logger]
        """The list of initialised Lightning logger objects for current task `self.task_id`. Instantiated in `instantiate_lightning_loggers()`."""
        self.callbacks: list[Callback]
        """The list of initialised callback objects for current task `self.task_id`. Instantiated in `instantiate_callbacks()`."""

    def sanity_check(self) -> None:
        """Check the sanity of the config dict `self.cfg`.

        **Raises:**
        - **KeyError**: when required fields in experiment config are missing, including `cl_paradigm`, `num_tasks`, `test`, `output_dir_name`.
        - **ValueError**: when the value of `cl_paradigm` is not 'TIL' or 'CIL', or when the number of tasks is larger than the number of tasks in the CL dataset.
        """
        if not self.cfg.get("cl_paradigm"):
            raise KeyError(
                "Field cl_paradigm should be specified in experiment config!"
            )

        if self.cfg.cl_paradigm not in ["TIL", "CIL"]:
            raise ValueError(
                f"Field cl_paradigm should be either 'TIL' or 'CIL' but got {self.cfg.cl_paradigm}!"
            )

        if not self.cfg.get("num_tasks"):
            raise KeyError("Field num_tasks should be specified in experiment config!")

        if not self.cfg.cl_dataset.get("num_tasks"):
            raise KeyError("Field num_tasks should be specified in cl_dataset config!")

        if self.cfg.num_tasks > self.cfg.cl_dataset.num_tasks:
            raise ValueError(
                f"The experiment is set to run {self.cfg.num_tasks} tasks whereas only {self.cfg.cl_dataset.num_tasks} exist in current cl_dataset setting!"
            )

        # `test` is a boolean, so check for presence rather than truthiness;
        # otherwise `test: False` would be rejected as a missing field.
        if self.cfg.get("test") is None:
            raise KeyError("Field test should be specified in experiment config!")

        if not self.cfg.get("output_dir_name"):
            raise KeyError(
                "Field output_dir_name should be specified in experiment config!"
            )

    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        """Instantiate the CL dataset object from cl_dataset config.

        **Args:**
        - **cl_dataset_cfg** (`DictConfig`): the cl_dataset config dict.
        """
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            cl_dataset_cfg.get("_target_"),
        )
        self.cl_dataset: LightningDataModule = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            cl_dataset_cfg.get("_target_"),
        )

    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
        """Instantiate the CL backbone network object from backbone config.

        **Args:**
        - **backbone_cfg** (`DictConfig`): the backbone config dict.
        """
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.CLBackbone)...",
            backbone_cfg.get("_target_"),
        )
        self.backbone: nn.Module = hydra.utils.instantiate(backbone_cfg)
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.CLBackbone) instantiated!",
            backbone_cfg.get("_target_"),
        )

    def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:
        """Instantiate the CL output heads object according to field `cl_paradigm` and backbone `output_dim` in the config.

        **Args:**
        - **cl_paradigm** (`str`): the CL paradigm, either 'TIL' or 'CIL'.
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug(
            "CL paradigm is set as %s. Instantiating %s heads (torch.nn.Module)...",
            cl_paradigm,
            cl_paradigm,
        )
        self.heads: HeadsTIL | HeadsCIL = (
            HeadsTIL(input_dim=input_dim)
            if cl_paradigm == "TIL"
            else HeadsCIL(input_dim=input_dim)
        )
        pylogger.debug("%s heads (torch.nn.Module) instantiated! ", cl_paradigm)

    def instantiate_cl_algorithm(self, cl_algorithm_cfg: DictConfig) -> None:
        """Instantiate the cl_algorithm object from cl_algorithm config.

        **Args:**
        - **cl_algorithm_cfg** (`DictConfig`): the cl_algorithm config dict.
        """
        pylogger.debug(
            "CL algorithm is set as <%s>. Instantiating <%s> (clarena.cl_algorithms.CLAlgorithm)...",
            cl_algorithm_cfg.get("_target_"),
            cl_algorithm_cfg.get("_target_"),
        )
        self.model: LightningModule = hydra.utils.instantiate(
            cl_algorithm_cfg,
            backbone=self.backbone,
            heads=self.heads,
        )
        pylogger.debug(
            "<%s> (clarena.cl_algorithms.CLAlgorithm) instantiated!",
            cl_algorithm_cfg.get("_target_"),
        )

    def instantiate_optimizer(
        self, optimizer_cfg: DictConfig | ListConfig, task_id: int
    ) -> None:
        """Instantiate the optimizer object for task `task_id` from optimizer config.

        **Args:**
        - **optimizer_cfg** (`DictConfig` or `ListConfig`): the optimizer config. If it's a `ListConfig`, it should contain one optimizer config per task; otherwise, it's a uniform optimizer config for all tasks.
        - **task_id** (`int`): the target task ID.
        """
        if isinstance(optimizer_cfg, ListConfig):
            pylogger.debug("Distinct optimizer config is applied to each task.")
            optimizer_cfg = optimizer_cfg[task_id - 1]  # task IDs count from 1
        elif isinstance(optimizer_cfg, DictConfig):
            pylogger.debug("Uniform optimizer config is applied to all tasks.")

        # Instantiate the optimizer only partially: its `params` argument comes from
        # the LightningModule and cannot be passed at this point. This must run for
        # both the per-task (ListConfig) and the uniform (DictConfig) case.
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer) for task %d...",
            optimizer_cfg.get("_target_"),
            task_id,
        )
        self.optimizer: Optimizer = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated for task %d!",
            optimizer_cfg.get("_target_"),
            task_id,
        )

    def instantiate_trainer(self, trainer_cfg: DictConfig, task_id: int) -> None:
        """Instantiate the trainer object for task `task_id` from trainer config.

        **Args:**
        - **trainer_cfg** (`DictConfig`): the trainer config dict. All tasks share the same trainer config, but each task gets its own trainer object.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating trainer <%s> (lightning.Trainer) for task %d...",
            trainer_cfg.get("_target_"),
            task_id,
        )
        self.trainer: Trainer = hydra.utils.instantiate(
            trainer_cfg, callbacks=self.callbacks, logger=self.lightning_loggers
        )
        pylogger.debug(
            "Trainer <%s> (lightning.Trainer) for task %d instantiated!",
            trainer_cfg.get("_target_"),
            task_id,
        )

    def instantiate_lightning_loggers(
        self, lightning_loggers_cfg: DictConfig, task_id: int
    ) -> None:
        """Instantiate the list of Lightning logger objects for task `task_id` from lightning_loggers config.

        **Args:**
        - **lightning_loggers_cfg** (`DictConfig`): the lightning_loggers config dict. All tasks share the same lightning_loggers config, but each task gets its own logger objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating Lightning loggers (lightning.Logger) for task %d...", task_id
        )
        self.lightning_loggers: list[Logger] = [
            hydra.utils.instantiate(
                lightning_logger, version=f"task_{task_id}"
            )  # name each logger's version directory "task_<task_id>"
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug(
            "Lightning loggers (lightning.Logger) for task %d instantiated!", task_id
        )

    def instantiate_callbacks(self, callbacks_cfg: DictConfig, task_id: int) -> None:
        """Instantiate the list of callback objects for task `task_id` from callbacks config.

        **Args:**
        - **callbacks_cfg** (`DictConfig`): the callbacks config dict. All tasks share the same callbacks config, but each task gets its own callback objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating callbacks (lightning.Callback) for task %d...", task_id
        )
        self.callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg.values()
        ]
        pylogger.debug(
            "Callbacks (lightning.Callback) for task %d instantiated!", task_id
        )

    def setup_task_id(self, task_id: int) -> None:
        """Set up the current `task_id` at the beginning of the continual learning process for a new task.

        **Args:**
        - **task_id** (`int`): the current task ID.
        """
        self.task_id = task_id

    def instantiate_global(self) -> None:
        """Instantiate global components for the entire CL experiment from `self.cfg`."""

        self.instantiate_cl_dataset(self.cfg.cl_dataset)
        self.instantiate_backbone(self.cfg.backbone)
        self.instantiate_heads(self.cfg.cl_paradigm, self.cfg.backbone.output_dim)
        self.instantiate_cl_algorithm(
            self.cfg.cl_algorithm
        )  # cl_algorithm should be instantiated after backbone and heads

    def setup_global(self) -> None:
        """Let the CL dataset know the CL paradigm so that it can define its CL class map."""
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)

    def instantiate_task_specific(self) -> None:
        """Instantiate task-specific components for the current task `self.task_id` from `self.cfg`."""

        self.instantiate_optimizer(self.cfg.optimizer, self.task_id)
        self.instantiate_callbacks(self.cfg.callbacks, self.task_id)
        self.instantiate_lightning_loggers(self.cfg.lightning_loggers, self.task_id)
        self.instantiate_trainer(
            self.cfg.trainer, self.task_id
        )  # trainer should be instantiated after loggers and callbacks

    def setup_task_specific(self) -> None:
        """Set up task-specific components to get ready for the current task `self.task_id`."""

        self.cl_dataset.setup_task_id(self.task_id)
        self.backbone.setup_task_id(self.task_id)
        self.model.setup_task_id(
            self.task_id,
            len(self.cl_dataset.cl_class_map(self.task_id)),
            self.optimizer,
        )

        pylogger.debug(
            "Datamodule, model and loggers are all set up ready for task %d!",
            self.task_id,
        )

    def fit_task(self) -> None:
        """Fit the model on the current task `self.task_id`. Also test the model if `self.test` is set to True."""

        self.trainer.fit(
            model=self.model,
            datamodule=self.cl_dataset,
        )

        if self.test:
            # test after train and validate
            self.trainer.test(
                model=self.model,
                datamodule=self.cl_dataset,
            )

    def fit(self) -> None:
        """The main method to run the continual learning experiment."""

        self.instantiate_global()
        self.setup_global()

        # task loop
        for task_id in range(1, self.num_tasks + 1):  # task IDs count from 1

            self.setup_task_id(task_id)
            self.instantiate_task_specific()
            self.setup_task_specific()

            self.fit_task()

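The task loop in fit() can be traced without Hydra, Lightning or torch. The sketch below uses a hypothetical TracedExperiment class that keeps CLExperiment's method names but only records each call, to show that the global components are built once while the task-specific ones are rebuilt for every task ID, starting from 1.

```python
class TracedExperiment:
    """Hypothetical stand-in for CLExperiment that only records the order of calls in fit()."""

    def __init__(self, num_tasks: int) -> None:
        self.num_tasks = num_tasks
        self.task_id = 0
        self.calls: list[str] = []

    def instantiate_global(self) -> None:
        # Dataset, backbone, heads and CL algorithm: built once per experiment.
        self.calls.append("instantiate_global")

    def setup_global(self) -> None:
        self.calls.append("setup_global")

    def setup_task_id(self, task_id: int) -> None:
        self.task_id = task_id
        self.calls.append(f"setup_task_id({task_id})")

    def instantiate_task_specific(self) -> None:
        # Optimizer, callbacks, loggers, then trainer: rebuilt for every task.
        self.calls.append(f"instantiate_task_specific({self.task_id})")

    def setup_task_specific(self) -> None:
        self.calls.append(f"setup_task_specific({self.task_id})")

    def fit_task(self) -> None:
        self.calls.append(f"fit_task({self.task_id})")

    def fit(self) -> None:
        # Same skeleton as CLExperiment.fit(): global setup once, then the task loop.
        self.instantiate_global()
        self.setup_global()
        for task_id in range(1, self.num_tasks + 1):  # task IDs count from 1
            self.setup_task_id(task_id)
            self.instantiate_task_specific()
            self.setup_task_specific()
            self.fit_task()


experiment = TracedExperiment(num_tasks=2)
experiment.fit()
print(experiment.calls)
```

With num_tasks=2 the trace holds the two global steps followed by four steps per task, ten calls in total.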
The base class for continual learning experiments.

CLExperiment(cfg: omegaconf.dictconfig.DictConfig)

Initializes the CL experiment object with a complete configuration.

Args:

  • cfg (DictConfig): the complete config dict for the CL experiment.
cfg: omegaconf.dictconfig.DictConfig

Store the complete config dict for any future reference.

cl_paradigm: str

Store the continual learning paradigm, either 'TIL' (Task-Incremental Learning) or 'CIL' (Class-Incremental Learning). Parsed from config and used to instantiate the correct heads object and set up the CL dataset.

num_tasks: int

Store the number of tasks to be conducted in this experiment. Parsed from config and used in the task loop.

test: bool

Store whether to test the model after training and validation. Parsed from config and used in the task loop.

output_dir_name: str

Store the name of the output directory to store the logs and checkpoints. Parsed from config and helps any output operation locate the correct directory.

task_id: int

Task ID counter indicating which task is being processed. Updated automatically during the task loop.

cl_dataset: CLDataset

CL dataset object. Instantiated in instantiate_cl_dataset().

backbone: CLBackbone

Backbone network object. Instantiated in instantiate_backbone().

heads: HeadsTIL | HeadsCIL

CL output heads object. Instantiated in instantiate_heads().

model: CLAlgorithm

CL model object. Instantiated in instantiate_cl_algorithm().

optimizer: torch.optim.optimizer.Optimizer

Optimizer object for current task self.task_id. Instantiated in instantiate_optimizer().

trainer: lightning.pytorch.trainer.trainer.Trainer

Trainer object for current task self.task_id. Instantiated in instantiate_trainer().

lightning_loggers: list[lightning.pytorch.loggers.logger.Logger]

The list of initialised Lightning logger objects for current task self.task_id. Instantiated in instantiate_lightning_loggers().

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

The list of initialised callback objects for current task self.task_id. Instantiated in instantiate_callbacks().

def sanity_check(self) -> None:

Check the sanity of the config dict self.cfg.

Raises:

  • KeyError: when required fields in experiment config are missing, including cl_paradigm, num_tasks, test, output_dir_name.
  • ValueError: when the value of cl_paradigm is not 'TIL' or 'CIL', or when the number of tasks is larger than the number of tasks in the CL dataset.
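The checks above can be restated on a plain dict as a stdlib-only sketch. The function, the dict form and the example values are assumptions for illustration (CLExperiment itself works on an OmegaConf DictConfig). Because test is a boolean, the sketch checks it for presence (is None) rather than truthiness, so that test: False is accepted instead of being reported as a missing field.

```python
from typing import Any


def sanity_check(cfg: dict[str, Any]) -> None:
    """Stdlib sketch of the CLExperiment.sanity_check rules, applied to a plain dict."""
    if not cfg.get("cl_paradigm"):
        raise KeyError("Field cl_paradigm should be specified in experiment config!")
    if cfg["cl_paradigm"] not in ("TIL", "CIL"):
        raise ValueError(f"Field cl_paradigm should be either 'TIL' or 'CIL' but got {cfg['cl_paradigm']}!")
    if not cfg.get("num_tasks"):
        raise KeyError("Field num_tasks should be specified in experiment config!")
    dataset_num_tasks = cfg.get("cl_dataset", {}).get("num_tasks")
    if not dataset_num_tasks:
        raise KeyError("Field num_tasks should be specified in cl_dataset config!")
    if cfg["num_tasks"] > dataset_num_tasks:
        raise ValueError(f"The experiment is set to run {cfg['num_tasks']} tasks whereas only {dataset_num_tasks} exist!")
    # `test` is a boolean: require that it is present, not that it is truthy.
    if cfg.get("test") is None:
        raise KeyError("Field test should be specified in experiment config!")
    if not cfg.get("output_dir_name"):
        raise KeyError("Field output_dir_name should be specified in experiment config!")


example_cfg = {
    "cl_paradigm": "TIL",
    "num_tasks": 5,
    "test": False,  # a valid value, not a missing field
    "output_dir_name": "example_run",  # hypothetical name
    "cl_dataset": {"num_tasks": 10},
}
sanity_check(example_cfg)  # passes without raising
```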
def instantiate_cl_dataset(self, cl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:

Instantiate the CL dataset object from cl_dataset config.

Args:

  • cl_dataset_cfg (DictConfig): the cl_dataset config dict.
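Once instantiated, the dataset is told the CL paradigm in setup_global(), and setup_task_specific() uses len(self.cl_dataset.cl_class_map(self.task_id)) as the number of classes of the new task. The following minimal sketch shows one common way such a class map can differ between the two paradigms; the function, the class order and the split into equal-sized tasks are assumptions, not CLArena's CLDataset implementation.

```python
def cl_class_map(
    task_id: int, class_order: list[int], classes_per_task: int, cl_paradigm: str
) -> dict[int, int]:
    """Hypothetical class map for task `task_id` (1-based): original label -> CL label."""
    start = (task_id - 1) * classes_per_task
    task_classes = class_order[start : start + classes_per_task]
    if cl_paradigm == "TIL":
        offset = 0  # every task has its own head, so CL labels restart at 0
    elif cl_paradigm == "CIL":
        offset = start  # one shared output space, so CL labels keep growing across tasks
    else:
        raise ValueError(f"Unknown CL paradigm: {cl_paradigm}")
    return {c: offset + i for i, c in enumerate(task_classes)}


class_order = [3, 7, 1, 0, 5, 2]  # hypothetical shuffled order of 6 original classes
print(cl_class_map(2, class_order, 2, "TIL"))  # {1: 0, 0: 1}
print(cl_class_map(2, class_order, 2, "CIL"))  # {1: 2, 0: 3}
```

In both paradigms the map for a task has the same length, which is all setup_task_specific() needs to size the new head.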
def instantiate_backbone(self, backbone_cfg: omegaconf.dictconfig.DictConfig) -> None:

Instantiate the CL backbone network object from backbone config.

Args:

  • backbone_cfg (DictConfig): the backbone config dict.
def instantiate_heads(self, cl_paradigm: str, input_dim: int) -> None:

Instantiate the CL output heads object according to field cl_paradigm and backbone output_dim in the config.

Args:

  • cl_paradigm (str): the CL paradigm, either 'TIL' or 'CIL'.
  • input_dim (int): the input dimension of the heads. Must be equal to the output_dim of the connected backbone.
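instantiate_heads() picks HeadsTIL or HeadsCIL from cl_paradigm and sizes them with the backbone's output_dim. The framework-free sketch below illustrates the usual difference between the two paradigms; SketchHeads, LinearHead and add_task are hypothetical names, and CLArena's actual heads are torch modules whose internals may differ. In TIL the given task ID selects one task-specific head, while in CIL the outputs of all heads learned so far are concatenated, so prediction ranges over every class seen so far.

```python
import random


class LinearHead:
    """Tiny bias-free dense layer on plain lists, standing in for torch.nn.Linear."""

    def __init__(self, input_dim: int, num_classes: int, rng: random.Random) -> None:
        self.weights = [
            [rng.uniform(-1.0, 1.0) for _ in range(input_dim)] for _ in range(num_classes)
        ]

    def __call__(self, feature: list[float]) -> list[float]:
        return [sum(w * x for w, x in zip(row, feature)) for row in self.weights]


class SketchHeads:
    """Hypothetical multi-head output layer with one LinearHead per task."""

    def __init__(self, cl_paradigm: str, input_dim: int, seed: int = 0) -> None:
        self.cl_paradigm = cl_paradigm
        self.input_dim = input_dim  # must equal the backbone's output_dim
        self.heads: dict[int, LinearHead] = {}
        self.rng = random.Random(seed)

    def add_task(self, task_id: int, num_classes: int) -> None:
        # One new head per task, sized to that task's number of classes.
        if task_id not in self.heads:
            self.heads[task_id] = LinearHead(self.input_dim, num_classes, self.rng)

    def forward(self, feature: list[float], task_id: int) -> list[float]:
        if len(feature) != self.input_dim:
            raise ValueError("feature size must equal input_dim (the backbone's output_dim)")
        if self.cl_paradigm == "TIL":
            # TIL: the task ID is given, so only that task's head is used.
            return self.heads[task_id](feature)
        # CIL: logits of all heads up to `task_id` are concatenated into one output space.
        logits: list[float] = []
        for t in sorted(self.heads):
            if t <= task_id:
                logits.extend(self.heads[t](feature))
        return logits


for paradigm in ("TIL", "CIL"):
    heads = SketchHeads(paradigm, input_dim=4)
    heads.add_task(1, num_classes=2)
    heads.add_task(2, num_classes=3)
    print(paradigm, len(heads.forward([0.1, 0.2, 0.3, 0.4], task_id=2)))  # TIL 3, CIL 5
```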
def instantiate_cl_algorithm(self, cl_algorithm_cfg: omegaconf.dictconfig.DictConfig) -> None:

Instantiate the cl_algorithm object from cl_algorithm config.

Args:

  • cl_algorithm_cfg (DictConfig): the cl_algorithm config dict.
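instantiate_cl_algorithm() builds the algorithm from its config and passes the already-built backbone and heads as extra keyword arguments. The core of that pattern, resolving the _target_ dotted path and calling it with the remaining entries, can be sketched with importlib. This is a simplified, non-recursive stand-in written for illustration, not hydra.utils.instantiate itself (which also handles nested configs, _partial_ and more); datetime.timedelta is used only as a convenient standard-library target.

```python
import importlib
from typing import Any


def instantiate(cfg: dict[str, Any], **overrides: Any) -> Any:
    """Simplified, non-recursive stand-in for hydra.utils.instantiate."""
    # "_target_" is a dotted path: everything before the last dot is the module.
    module_path, _, attr_name = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    kwargs = {key: value for key, value in cfg.items() if key != "_target_"}
    # Call-site keyword arguments take precedence over config entries, the way
    # instantiate_cl_algorithm() supplies backbone= and heads= at call time.
    kwargs.update(overrides)
    return target(**kwargs)


delta = instantiate({"_target_": "datetime.timedelta", "days": 1}, hours=2)
print(delta)  # 1 day, 2:00:00
```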
def instantiate_optimizer(self, optimizer_cfg: omegaconf.dictconfig.DictConfig | omegaconf.listconfig.ListConfig, task_id: int) -> None:

Instantiate the optimizer object for task task_id from optimizer config.

Args:

  • optimizer_cfg (DictConfig or ListConfig): the optimizer config. If it's a ListConfig, it should contain one optimizer config per task, and task task_id uses the entry at index task_id - 1; otherwise, it's a uniform optimizer config for all tasks.
  • task_id (int): the target task ID.
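Both steps of instantiate_optimizer(), picking the config for the current task and building the optimizer without its params, can be sketched with the standard library. The sketch assumes the optimizer config sets Hydra's _partial_: true, in which case hydra.utils.instantiate returns a functools.partial; HypotheticalSGD, the helper names and the learning rates are made up for illustration.

```python
import functools
from typing import Any, Callable, Iterable, Union


class HypotheticalSGD:
    """Stand-in optimizer that, like torch.optim.SGD, needs `params` when constructed."""

    def __init__(self, params: Iterable[Any], lr: float) -> None:
        self.params = list(params)
        self.lr = lr


def select_optimizer_cfg(optimizer_cfg: Union[dict, list], task_id: int) -> dict:
    if isinstance(optimizer_cfg, list):
        # One config per task; task IDs count from 1, so task t uses index t - 1.
        return optimizer_cfg[task_id - 1]
    # A single mapping is shared by all tasks.
    return optimizer_cfg


def partial_optimizer(cfg: dict[str, Any]) -> Callable[..., HypotheticalSGD]:
    # Bind every hyperparameter now; `params` is supplied later by the model.
    return functools.partial(HypotheticalSGD, **cfg)


per_task_cfg = [{"lr": 0.1}, {"lr": 0.01}, {"lr": 0.001}]
make_optimizer = partial_optimizer(select_optimizer_cfg(per_task_cfg, task_id=2))
optimizer = make_optimizer(params=[0.5, -0.5])  # hypothetical parameters
print(optimizer.lr)  # 0.01
```

setup_task_specific() hands the partial optimizer to the model through self.model.setup_task_id(...), which is presumably where the model's parameters get bound.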
def instantiate_trainer(self, trainer_cfg: omegaconf.dictconfig.DictConfig, task_id: int) -> None:

Instantiate the trainer object for task task_id from trainer config.

Args:

  • trainer_cfg (DictConfig): the trainer config dict. All tasks share the same trainer config, but each task gets its own trainer object.
  • task_id (int): the target task ID.
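The trainer is built from static config plus objects that only exist at runtime (`callbacks`, `logger`), which are passed as extra keyword arguments to `hydra.utils.instantiate`. The sketch below is a simplified stand-in for that call, not Hydra's implementation: the hypothetical `instantiate` imports the callable named by `_target_` and calls it with the config entries plus the runtime kwargs, letting the kwargs win on a clash. `types.SimpleNamespace` plays the role of `lightning.Trainer` so the example runs anywhere.

```python
import importlib
from typing import Any, Mapping


def instantiate(cfg: Mapping[str, Any], **extra_kwargs: Any) -> Any:
    """Hypothetical, simplified stand-in for hydra.utils.instantiate.

    Resolves the dotted `_target_` path, then calls it with the remaining
    config entries; runtime keyword arguments override config values.
    """
    module_path, _, attr_name = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    kwargs.update(extra_kwargs)
    return target(**kwargs)


trainer_cfg = {"_target_": "types.SimpleNamespace", "max_epochs": 2}
trainer = instantiate(trainer_cfg, callbacks=["my_callback"], logger=["my_logger"])
print(trainer.max_epochs, trainer.callbacks, trainer.logger)
# 2 ['my_callback'] ['my_logger']
```

Calling `instantiate` again for the next task returns a new object from the same config, which is the "same config, different objects" behaviour the docstring describes.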
def instantiate_lightning_loggers(self, lightning_loggers_cfg: omegaconf.dictconfig.DictConfig, task_id: int) -> None:
    def instantiate_lightning_loggers(
        self, lightning_loggers_cfg: DictConfig, task_id: int
    ) -> None:
        """Instantiate the list of Lightning logger objects for task `task_id` from lightning_loggers config.

        **Args:**
        - **lightning_loggers_cfg** (`DictConfig`): the lightning_loggers config dict. All tasks share the same lightning_loggers config, but each task gets its own logger objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating Lightning loggers (lightning.Logger) for task %d...", task_id
        )
        self.lightning_loggers: list[Logger] = [
            hydra.utils.instantiate(
                lightning_logger, version=f"task_{task_id}"
            )  # set the version to "task_<task_id>" so each task logs to its own directory
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug(
            "Lightning loggers (lightning.Logger) for task %d instantiated!", task_id
        )

Instantiate the list of Lightning logger objects for task task_id from lightning_loggers config.

Args:

  • lightning_loggers_cfg (DictConfig): the lightning_loggers config dict. All tasks share the same lightning_loggers config, but each task gets its own logger objects.
  • task_id (int): the target task ID.
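Passing `version=f"task_{task_id}"` is what separates each task's logs. The sketch below assumes a `<save_dir>/<name>/<version>` directory layout, which is how Lightning's TensorBoardLogger and CSVLogger organise their output when `version` is a string; the helper names `per_task_logger_cfgs` and `log_dir` and the `save_dir`/`name` values are hypothetical.

```python
from pathlib import PurePosixPath
from typing import Any, Mapping


def per_task_logger_cfgs(
    loggers_cfg: Mapping[str, Mapping[str, Any]], task_id: int
) -> list[dict]:
    """Copy each logger config and override its version with the task name."""
    return [{**cfg, "version": f"task_{task_id}"} for cfg in loggers_cfg.values()]


def log_dir(cfg: Mapping[str, Any]) -> str:
    """Assumed layout: <save_dir>/<name>/<version>."""
    return str(PurePosixPath(cfg["save_dir"], cfg["name"], cfg["version"]))


loggers_cfg = {"tensorboard": {"save_dir": "logs", "name": "my_experiment"}}
for task_id in (1, 2):
    print([log_dir(cfg) for cfg in per_task_logger_cfgs(loggers_cfg, task_id)])
# ['logs/my_experiment/task_1']
# ['logs/my_experiment/task_2']
```

The shared config is never mutated; each task only gets a copy with its own `version`.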
def instantiate_callbacks(self, callbacks_cfg: omegaconf.dictconfig.DictConfig, task_id: int) -> None:
    def instantiate_callbacks(self, callbacks_cfg: DictConfig, task_id: int) -> None:
        """Instantiate the list of callback objects for task `task_id` from callbacks config.

        **Args:**
        - **callbacks_cfg** (`DictConfig`): the callbacks config dict. All tasks share the same callbacks config, but each task gets its own callback objects.
        - **task_id** (`int`): the target task ID.
        """
        pylogger.debug(
            "Instantiating callbacks (lightning.Callback) for task %d...", task_id
        )
        self.callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg.values()
        ]
        pylogger.debug(
            "Callbacks (lightning.Callback) for task %d instantiated!", task_id
        )

Instantiate the list of callback objects for task task_id from callbacks config.

Args:

  • callbacks_cfg (DictConfig): the callbacks config dict. All tasks share the same callbacks config, but each task gets its own callback objects.
  • task_id (int): the target task ID.
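One practical reason to rebuild callbacks from config for every task, rather than reuse one set of objects, is that callbacks can carry state. The sketch below uses a hypothetical stateful callback (`EpochCounter`) and a hypothetical name-to-class `REGISTRY` in place of Hydra's `_target_` resolution; it only illustrates that state from task 1 does not leak into task 2 when fresh objects are created.

```python
from typing import Any, Callable, Mapping


class EpochCounter:
    """Hypothetical stateful callback that counts the epochs it has seen."""

    def __init__(self) -> None:
        self.epochs_seen = 0

    def on_train_epoch_end(self) -> None:
        self.epochs_seen += 1


# Hypothetical registry standing in for Hydra's `_target_` lookup.
REGISTRY: dict[str, Callable[..., Any]] = {"EpochCounter": EpochCounter}


def build_callbacks(callbacks_cfg: Mapping[str, Mapping[str, Any]]) -> list[Any]:
    """Create a fresh callback object for every entry in the config."""
    return [
        REGISTRY[cfg["_target_"]](**{k: v for k, v in cfg.items() if k != "_target_"})
        for cfg in callbacks_cfg.values()
    ]


callbacks_cfg = {"counter": {"_target_": "EpochCounter"}}
task1_callbacks = build_callbacks(callbacks_cfg)
for _ in range(3):  # pretend task 1 trains for 3 epochs
    task1_callbacks[0].on_train_epoch_end()
task2_callbacks = build_callbacks(callbacks_cfg)
print(task1_callbacks[0].epochs_seen, task2_callbacks[0].epochs_seen)
# 3 0
```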
def setup_task_id(self, task_id: int) -> None:
    def setup_task_id(self, task_id: int) -> None:
        """Set up the current task_id at the beginning of the continual learning process of a new task.

        **Args:**
        - **task_id** (`int`): the current task ID.
        """
        self.task_id = task_id

Set up the current task_id at the beginning of the continual learning process of a new task.

Args:

  • task_id (int): the current task ID.
def instantiate_global(self) -> None:
    def instantiate_global(self) -> None:
        """Instantiate global components for the entire CL experiment from `self.cfg`."""

        self.instantiate_cl_dataset(self.cfg.cl_dataset)
        self.instantiate_backbone(self.cfg.backbone)
        self.instantiate_heads(self.cfg.cl_paradigm, self.cfg.backbone.output_dim)
        self.instantiate_cl_algorithm(
            self.cfg.cl_algorithm
        )  # cl_algorithm should be instantiated after backbone and heads

Instantiate global components for the entire CL experiment from self.cfg.

def setup_global(self) -> None:
    def setup_global(self) -> None:
        """Let the CL dataset know the CL paradigm so that it can define its CL class map."""
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)

Let the CL dataset know the CL paradigm so that it can define its CL class map.
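Why the dataset needs the paradigm: the label a class is trained under depends on whether heads are per task (TIL) or shared (CIL). The function below is a hypothetical sketch, not CLArena's `cl_class_map`; it assumes classes are assigned to tasks in contiguous blocks of `classes_per_task`. In TIL, labels restart at 0 for each task's own head; in CIL, all tasks share one label space, so original indices are kept.

```python
def cl_class_map(task_id: int, cl_paradigm: str, classes_per_task: int) -> dict[int, int]:
    """Hypothetical class map: original class label -> label used for training.

    Assumes task `task_id` (counted from 1) owns a contiguous block of classes.
    """
    first = (task_id - 1) * classes_per_task
    original_classes = range(first, first + classes_per_task)
    if cl_paradigm == "TIL":
        # task-incremental: each task has its own head, so labels restart at 0
        return {c: i for i, c in enumerate(original_classes)}
    if cl_paradigm == "CIL":
        # class-incremental: one shared label space across all tasks
        return {c: c for c in original_classes}
    raise ValueError(f"unknown CL paradigm: {cl_paradigm}")


print(cl_class_map(2, "TIL", classes_per_task=2))  # {2: 0, 3: 1}
print(cl_class_map(2, "CIL", classes_per_task=2))  # {2: 2, 3: 3}
```

In either case the map for a task has one entry per class of that task, which matches how `setup_task_specific` uses `len(self.cl_dataset.cl_class_map(self.task_id))` as the number of classes for the new task.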

def instantiate_task_specific(self) -> None:
    def instantiate_task_specific(self) -> None:
        """Instantiate task-specific components for the current task `self.task_id` from `self.cfg`."""

        self.instantiate_optimizer(self.cfg.optimizer, self.task_id)
        self.instantiate_callbacks(self.cfg.callbacks, self.task_id)
        self.instantiate_lightning_loggers(self.cfg.lightning_loggers, self.task_id)
        self.instantiate_trainer(
            self.cfg.trainer, self.task_id
        )  # trainer should be instantiated after loggers and callbacks

Instantiate task-specific components for the current task self.task_id from self.cfg.

def setup_task_specific(self) -> None:
    def setup_task_specific(self) -> None:
        """Set up task-specific components to get ready for the current task `self.task_id`."""

        self.cl_dataset.setup_task_id(self.task_id)
        self.backbone.setup_task_id(self.task_id)
        self.model.setup_task_id(
            self.task_id,
            len(self.cl_dataset.cl_class_map(self.task_id)),
            self.optimizer,
        )

        pylogger.debug(
            "Datamodule, model and loggers are all set up and ready for task %d!",
            self.task_id,
        )

Set up task-specific components to get ready for the current task self.task_id.

def fit_task(self) -> None:
    def fit_task(self) -> None:
        """Fit the model on the current task `self.task_id`. Also test the model if `self.test` is set to True."""

        self.trainer.fit(
            model=self.model,
            datamodule=self.cl_dataset,
        )

        if self.test:
            # test after train and validate
            self.trainer.test(
                model=self.model,
                datamodule=self.cl_dataset,
            )

Fit the model on the current task self.task_id. Also test the model if self.test is set to True.

def fit(self) -> None:
    def fit(self) -> None:
        """The main method to run the continual learning experiment."""

        self.instantiate_global()
        self.setup_global()

        # task loop
        for task_id in range(1, self.num_tasks + 1):  # task ID counts from 1

            self.setup_task_id(task_id)
            self.instantiate_task_specific()
            self.setup_task_specific()

            self.fit_task()

The main method to run the continual learning experiment.
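The order of calls in `fit` is the core of the experiment lifecycle: global components (dataset, backbone, heads, CL algorithm) are built once, while the optimizer, callbacks, loggers and trainer are rebuilt for every task. The skeleton below is a hypothetical stand-in for `CLExperiment` that mirrors the method names above but only records the call sequence, so the lifecycle can be inspected without Lightning or Hydra.

```python
class ExperimentSkeleton:
    """Hypothetical stand-in for CLExperiment that only records the call order."""

    def __init__(self, num_tasks: int, test: bool = True) -> None:
        self.num_tasks = num_tasks
        self.test = test
        self.task_id = 0
        self.calls: list[str] = []

    def instantiate_global(self) -> None:
        self.calls.append("instantiate_global")  # runs once per experiment

    def setup_global(self) -> None:
        self.calls.append("setup_global")  # runs once per experiment

    def setup_task_id(self, task_id: int) -> None:
        self.task_id = task_id

    def instantiate_task_specific(self) -> None:
        self.calls.append(f"instantiate_task_specific({self.task_id})")

    def setup_task_specific(self) -> None:
        self.calls.append(f"setup_task_specific({self.task_id})")

    def fit_task(self) -> None:
        self.calls.append(f"fit({self.task_id})")
        if self.test:  # mirrors the optional test after training
            self.calls.append(f"test({self.task_id})")

    def fit(self) -> None:
        self.instantiate_global()
        self.setup_global()
        for task_id in range(1, self.num_tasks + 1):  # task ID counts from 1
            self.setup_task_id(task_id)
            self.instantiate_task_specific()
            self.setup_task_specific()
            self.fit_task()


experiment = ExperimentSkeleton(num_tasks=2)
experiment.fit()
print(experiment.calls)
```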