clarena.cl_algorithms

Continual Learning Algorithms

This submodule provides the continual learning algorithms in CLArena.

Please note that this is API documentation. Please refer to the main documentation pages for more information about the CL algorithms and how to configure and implement them:

  • Configure CL Algorithm: https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/cl-algorithm
  • Implement Your CL Algorithm Class: https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-algorithm
  • A Beginners' Guide to Continual Learning (Methodology Overview): https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-methodology

The algorithms are implemented as subclasses of CLAlgorithm.
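For illustration, here is a minimal sketch of such a subclass. It is hypothetical: the class name is made up, and the call signatures of the backbone and heads are assumptions for this example rather than the exact CLArena interfaces.

from clarena.cl_algorithms import CLAlgorithm

class NaiveFinetuning(CLAlgorithm):
    r"""A hypothetical minimal algorithm: plain finetuning on each task, with no mechanism against forgetting."""

    def training_step(self, batch, batch_idx):
        x, y = batch
        feature = self.backbone(x)  # assumed: the backbone maps inputs to a feature tensor
        logits = self.heads(feature, self.task_id)  # assumed: the heads dispatch on the current task ID
        loss = self.criterion(logits, y)  # cross-entropy by default (set in CLAlgorithm.__init__)
        return loss  # Lightning backpropagates this with the optimizer from configure_optimizers()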

 1r"""
 2
 3# Continual Learning Algorithms
 4
 5This submodule provides the **continual learning algorithms** in CLArena. 
 6
 7Please note that this is an API documantation. Please refer to the main documentation pages for more information about the backbone networks and how to configure and implement them:
 8
 9- [**Configure CL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/cl-algorithm)
10- [**Implement Your CL Algorithm Class**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-algorithm)
11- [**A Beginners' Guide to Continual Learning (Methodology Overview)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-methodology)
12
13
14The algorithms are implemented as subclasses of `CLAlgorithm`.
15
16"""
17
18from .base import CLAlgorithm
19
20# finetuning first
21from .finetuning import Finetuning
22from .fix import Fix
23
24from .lwf import LwF
25from .ewc import EWC
26from .cbp import CBP
27
28from .hat import HAT
29from .adahat import AdaHAT
30from .cbphat import CBPHAT
31
32
33__all__ = [
34    "CLAlgorithm",
35    "regularisers",
36    "finetuning",
37    "fix",
38    "lwf",
39    "ewc",
40    "hat",
41    "cbp",
42    "adahat",
43    "cbphat",
44]
class CLAlgorithm(lightning.pytorch.core.module.LightningModule):
class CLAlgorithm(LightningModule):
    r"""The base class of continual learning algorithms, inherited from `LightningModule`."""

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
    ) -> None:
        r"""Initialise the CL algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        """
        LightningModule.__init__(self)

        self.backbone: CLBackbone = backbone
        r"""Store the backbone network."""
        self.heads: HeadsTIL | HeadsCIL = heads
        r"""Store the output heads."""
        self.optimizer: Optimizer
        r"""Store the optimizer object (partially initialised) for the backpropagation of task `self.task_id`. Will be equipped with parameters in `configure_optimizers()`."""
        self.criterion = nn.CrossEntropyLoss()
        r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop."""

        CLAlgorithm.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: if the `output_dim` of the backbone network is not equal to the `input_dim` of the CL heads.
        """
        if self.backbone.output_dim != self.heads.input_dim:
            raise ValueError(
                "The output_dim of the backbone network should be equal to the input_dim of the CL heads!"
            )

    def setup_task_id(
        self, task_id: int, num_classes_t: int, optimizer: Optimizer
    ) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before the `forward()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        - **num_classes_t** (`int`): the number of classes in the task.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialised) for the task `self.task_id`.
        """
        self.task_id = task_id
        self.heads.setup_task_id(task_id, num_classes_t)
        self.optimizer = optimizer

    def configure_optimizers(self) -> Optimizer:
        r"""
        Configure optimizer hooks by Lightning.
        See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details.
        """
        # Finish the partially initialised optimizer by specifying the model parameters.
        # The `parameters()` method of this `CLAlgorithm` (inherited from `LightningModule`)
        # returns both backbone and heads parameters.
        return self.optimizer(params=self.parameters())

The base class of continual learning algorithms, inherited from LightningModule.

def __init__(self, backbone: CLBackbone, heads: HeadsTIL | HeadsCIL) -> None:

Initialise the CL algorithm with the network.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL): output heads.

backbone: CLBackbone

Store the backbone network.

heads: HeadsTIL | HeadsCIL

Store the output heads.

optimizer: torch.optim.optimizer.Optimizer

Store the optimizer object (partially initialised) for the backpropagation of task self.task_id. Will be equipped with parameters in configure_optimizers().

criterion

The loss function between the output logits and the target labels. Default is cross-entropy loss.

task_id: int

Task ID counter indicating which task is being processed. Self-updated during the task loop.
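For example, a concrete algorithm such as Finetuning (exported by this submodule) is constructed from a backbone and a set of heads. The import paths and constructor arguments for the backbone and heads below are assumptions for illustration only:

from clarena.cl_algorithms import Finetuning
from clarena.backbones import MLP      # hypothetical import path and constructor
from clarena.cl_heads import HeadsTIL  # hypothetical import path and constructor

backbone = MLP(output_dim=64)
heads = HeadsTIL(input_dim=64)  # must equal backbone.output_dim, or sanity_check() raises ValueError
algo = Finetuning(backbone=backbone, heads=heads)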

def sanity_check(self) -> None:

Check the sanity of the arguments.

Raises:

  • ValueError: if the output_dim of the backbone network is not equal to the input_dim of the CL heads.
def setup_task_id(self, task_id: int, num_classes_t: int, optimizer: torch.optim.optimizer.Optimizer) -> None:

Set up which task's dataset the CL experiment is on. This must be done before the forward() method is called.

Args:

  • task_id (int): the target task ID.
  • num_classes_t (int): the number of classes in the task.
  • optimizer (Optimizer): the optimizer object (partially initialised) for the task self.task_id.
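A sketch of the per-task setup, assuming algo is a constructed CLAlgorithm as above. functools.partial is one way to build the "partially initialised" optimizer this argument expects (CLArena's configs may realise it differently):

from functools import partial
from torch.optim import SGD

# everything but the model parameters is fixed up front
optimizer = partial(SGD, lr=0.01, momentum=0.9)

# register task 1 before training on it: sets self.task_id, registers the task
# with the heads (num_classes_t output classes), and stores the optimizer
algo.setup_task_id(task_id=1, num_classes_t=10, optimizer=optimizer)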
def configure_optimizers(self) -> torch.optim.optimizer.Optimizer:

Configure optimizer hooks by Lightning. See the Lightning docs (https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details.
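To make the "partially initialised" pattern concrete, here is a standalone sketch in plain PyTorch (no CLArena imports) of what this hook effectively does:

from functools import partial

from torch import nn
from torch.optim import SGD, Optimizer

model = nn.Linear(4, 2)
optimizer_factory = partial(SGD, lr=0.01)  # partially initialised: no parameters yet

# equivalent to `return self.optimizer(params=self.parameters())` in configure_optimizers()
optimizer = optimizer_factory(params=model.parameters())
assert isinstance(optimizer, Optimizer)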