clarena.cl_algorithms
Continual Learning Algorithms
This submodule provides the continual learning algorithms in CLArena.
Please note that this is API documentation. Please refer to the main documentation pages for more information about the CL algorithms and how to configure and implement them:
- Configure CL Algorithm
- Implement Your CL Algorithm Class
- A Beginners' Guide to Continual Learning (Methodology Overview)
The algorithms are implemented as subclasses of `CLAlgorithm`.
```python
r"""

# Continual Learning Algorithms

This submodule provides the **continual learning algorithms** in CLArena.

Please note that this is API documentation. Please refer to the main documentation pages for more information about the CL algorithms and how to configure and implement them:

- [**Configure CL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/cl-algorithm)
- [**Implement Your CL Algorithm Class**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-algorithm)
- [**A Beginners' Guide to Continual Learning (Methodology Overview)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-methodology)


The algorithms are implemented as subclasses of `CLAlgorithm`.

"""

from .base import CLAlgorithm

# finetuning first
from .finetuning import Finetuning
from .fix import Fix

from .lwf import LwF
from .ewc import EWC
from .cbp import CBP

from .hat import HAT
from .adahat import AdaHAT
from .cbphat import CBPHAT


__all__ = [
    "CLAlgorithm",
    "regularisers",
    "finetuning",
    "fix",
    "lwf",
    "ewc",
    "hat",
    "cbp",
    "adahat",
    "cbphat",
]
```
```python
# Excerpt: the import lines that precede this class in the module are not shown.
class CLAlgorithm(LightningModule):
    r"""The base class of continual learning algorithms, inherited from `LightningModule`."""

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
    ) -> None:
        r"""Initialise the CL algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        """
        LightningModule.__init__(self)

        self.backbone: CLBackbone = backbone
        r"""Store the backbone network."""
        self.heads: HeadsTIL | HeadsCIL = heads
        r"""Store the output heads."""
        self.optimizer: Optimizer
        r"""Store the optimizer object (partially initialised) for the backpropagation of task `self.task_id`. Will be equipped with parameters in `configure_optimizers()`."""
        self.criterion = nn.CrossEntropyLoss()
        r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""

        CLAlgorithm.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: if the `output_dim` of backbone network is not equal to the `input_dim` of CL heads.
        """
        if self.backbone.output_dim != self.heads.input_dim:
            raise ValueError(
                "The output_dim of backbone network should be equal to the input_dim of CL heads!"
            )

    def setup_task_id(
        self, task_id: int, num_classes_t: int, optimizer: Optimizer
    ) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `forward()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        - **num_classes_t** (`int`): the number of classes in the task.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialised) for the task `self.task_id`.
        """
        self.task_id = task_id
        self.heads.setup_task_id(task_id, num_classes_t)
        self.optimizer = optimizer

    def configure_optimizers(self) -> Optimizer:
        r"""
        Configure optimizer hooks by Lightning.
        See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details.
        """
        # Finish the partially initialised optimizer by specifying the model parameters.
        # The `parameters()` method of this `CLAlgorithm` (inherited from `LightningModule`)
        # returns both backbone and heads parameters.
        return self.optimizer(params=self.parameters())
```
The base class of continual learning algorithms, inherited from `LightningModule`.
`CLAlgorithm(backbone: CLBackbone, heads: HeadsTIL | HeadsCIL)`
Initialise the CL algorithm with the network.
Args:
- backbone (`CLBackbone`): the backbone network.
- heads (`HeadsTIL` | `HeadsCIL`): the output heads.
`self.optimizer`: stores the optimizer object (partially initialised) for backpropagation on task `self.task_id`. It is equipped with the model parameters in `configure_optimizers()`.
`self.criterion`: the loss function between the output logits and the target labels. Defaults to cross-entropy loss (`nn.CrossEntropyLoss()`).
`self.task_id`: the task ID counter indicating which task is being processed. It is set by `setup_task_id()` as the task loop advances.
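For reference, the default criterion `nn.CrossEntropyLoss()` combines log-softmax with negative log-likelihood and averages over the batch. The stdlib-only sketch below reproduces that computation on plain lists of logits. It is an illustration, not CLArena or PyTorch code. The helper name `cross_entropy` is hypothetical, and it assumes class-index targets, no class weights and no label smoothing.

```python
from __future__ import annotations

import math


def cross_entropy(batch_logits: list[list[float]], targets: list[int]) -> float:
    """Mean cross-entropy over a batch, like nn.CrossEntropyLoss() with default settings.

    Hypothetical helper for illustration only: assumes class-index targets,
    no class weights, no label smoothing and reduction="mean".
    """
    if len(batch_logits) != len(targets) or not targets:
        raise ValueError("need one target per sample and a non-empty batch")
    total = 0.0
    for logits, target in zip(batch_logits, targets):
        # Numerically stable log-sum-exp: shift by the largest logit first.
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        # -log(softmax(logits)[target]) == logsumexp(logits) - logits[target]
        total += log_sum_exp - logits[target]
    return total / len(targets)


if __name__ == "__main__":
    # One sample, three classes, correct class 0.
    print(round(cross_entropy([[2.0, 1.0, 0.1]], [0]), 4))  # → 0.417
    # Equal logits over four classes give log(4), whatever the target.
    print(round(cross_entropy([[0.0, 0.0, 0.0, 0.0]], [2]), 4))  # → 1.3863
```

The loss is small when the target logit dominates, and it equals log C when the model is uniformly unsure over C classes.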
`sanity_check(self) -> None`
Check the sanity of the arguments.
Raises:
- ValueError: if the `output_dim` of the backbone network is not equal to the `input_dim` of the CL heads.
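The check only compares two integer attributes. The sketch below reproduces the same guard in plain Python. `ToyBackbone` and `ToyHeads` are hypothetical stand-ins for `CLBackbone` and `HeadsTIL`/`HeadsCIL`, and they model nothing except the `output_dim`/`input_dim` attributes that the check reads.

```python
from dataclasses import dataclass


@dataclass
class ToyBackbone:
    """Hypothetical stand-in for CLBackbone; only output_dim is modelled."""

    output_dim: int


@dataclass
class ToyHeads:
    """Hypothetical stand-in for HeadsTIL / HeadsCIL; only input_dim is modelled."""

    input_dim: int


def sanity_check(backbone: ToyBackbone, heads: ToyHeads) -> None:
    """Same condition as CLAlgorithm.sanity_check(): the feature size produced
    by the backbone must equal the feature size the heads expect."""
    if backbone.output_dim != heads.input_dim:
        raise ValueError(
            "The output_dim of backbone network should be equal to the input_dim of CL heads!"
        )


if __name__ == "__main__":
    sanity_check(ToyBackbone(output_dim=64), ToyHeads(input_dim=64))  # passes silently
    try:
        sanity_check(ToyBackbone(output_dim=64), ToyHeads(input_dim=128))
    except ValueError as err:
        print(f"rejected: {err}")
```

`__init__` calls `CLAlgorithm.sanity_check(self)` as its last step, so a dimension mismatch surfaces when the algorithm is constructed rather than at the first forward pass.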
`setup_task_id(self, task_id: int, num_classes_t: int, optimizer: Optimizer) -> None`
Set up the task that the CL experiment is currently on. This must be done before the `forward()` method is called.
Args:
- task_id (`int`): the target task ID.
- num_classes_t (`int`): the number of classes in the task.
- optimizer (`Optimizer`): the optimizer object (partially initialised) for task `self.task_id`.
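The sketch below illustrates the call order that the method implies: one `setup_task_id()` call per task, before any forward pass on that task. `ToyHeads` and `ToyAlgorithm` are hypothetical stand-ins, not CLArena classes. The guard in `forward()` only dramatises the ordering requirement, because the real `forward()` is not shown on this page. The task sequence values are made up.

```python
from __future__ import annotations


class ToyHeads:
    """Hypothetical stand-in for HeadsTIL / HeadsCIL: records each task's class count."""

    def __init__(self) -> None:
        self.task_id: int | None = None
        self.num_classes: dict[int, int] = {}

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        self.task_id = task_id
        self.num_classes[task_id] = num_classes_t


class ToyAlgorithm:
    """Hypothetical stand-in that mirrors the three steps of CLAlgorithm.setup_task_id()."""

    def __init__(self, heads: ToyHeads) -> None:
        self.heads = heads
        self.task_id: int | None = None
        self.optimizer = None

    def setup_task_id(self, task_id: int, num_classes_t: int, optimizer) -> None:
        self.task_id = task_id  # 1. record the current task
        self.heads.setup_task_id(task_id, num_classes_t)  # 2. prepare the heads
        self.optimizer = optimizer  # 3. keep this task's (partial) optimizer

    def forward(self, x: str) -> str:
        # Illustrative guard only: the real forward() is not shown on this page.
        if self.task_id is None:
            raise RuntimeError("call setup_task_id() before forward()")
        return f"task {self.task_id} output for {x!r}"


if __name__ == "__main__":
    algo = ToyAlgorithm(ToyHeads())
    # Example task sequence as (task_id, num_classes_t) pairs; values are made up.
    for task_id, num_classes_t in [(1, 10), (2, 5)]:
        algo.setup_task_id(task_id, num_classes_t, optimizer=None)  # None: placeholder optimizer
        print(algo.forward("batch"))
    print(algo.heads.num_classes)
```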
`configure_optimizers(self) -> Optimizer`
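Configure the optimizer (a Lightning hook). It completes the partially initialised `self.optimizer` by passing it `self.parameters()`, which is inherited from `LightningModule` and covers the parameters of both the backbone and the heads. See the Lightning documentation on `configure_optimizers` for more details.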
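A partially initialised optimizer is a callable with its hyperparameters already bound that still needs the parameters to optimise. The stdlib sketch below shows the pattern with `functools.partial`. `ToySGD` and `ToyModule` are hypothetical stand-ins for a real optimizer class and for a `CLAlgorithm`, and the string "parameters" are placeholders. With PyTorch, the same idea would read `functools.partial(torch.optim.SGD, lr=0.01)`, later called as `optimizer(params=model.parameters())`. How CLArena builds the partial object in practice is not covered here.

```python
from __future__ import annotations

from functools import partial


class ToySGD:
    """Hypothetical stand-in for an optimizer class such as torch.optim.SGD."""

    def __init__(self, params, lr: float, momentum: float = 0.0) -> None:
        self.params = list(params)  # materialise the parameter iterator
        self.lr = lr
        self.momentum = momentum


class ToyModule:
    """Hypothetical stand-in for a CLAlgorithm with backbone and head parameters."""

    def __init__(self, optimizer) -> None:
        self.backbone_params = ["backbone.w", "backbone.b"]  # placeholder parameters
        self.head_params = ["heads.1.w"]
        self.optimizer = optimizer  # partially initialised: no params bound yet

    def parameters(self):
        # Like nn.Module.parameters() (inherited by LightningModule):
        # yield the parameters of every part of the model.
        yield from self.backbone_params
        yield from self.head_params

    def configure_optimizers(self) -> ToySGD:
        # Finish the partial optimizer by supplying the model parameters.
        return self.optimizer(params=self.parameters())


if __name__ == "__main__":
    # Hyperparameters are bound up front; the parameters are bound later.
    partial_optimizer = partial(ToySGD, lr=0.01, momentum=0.9)
    opt = ToyModule(partial_optimizer).configure_optimizers()
    print(opt.lr, opt.params)  # → 0.01 ['backbone.w', 'backbone.b', 'heads.1.w']
```

`setup_task_id()` replaces `self.optimizer` for every task. As a result, the optimizer that `configure_optimizers()` returns is always built from the partial object supplied for the current task.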