clarena.cl_algorithms
Continual Learning Algorithms
This submodule provides the continual learning algorithms in CLArena.
Please note that this is API documentation. Please refer to the main documentation pages for more information about the CL algorithms and how to use and customize them:
- Configure your CL algorithm: https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/cl-algorithm
- Implement your CL algorithm: https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/cl-algorithm
- A beginners' guide to continual learning (CL algorithm): https://pengxiang-wang.com/posts/continual-learning-beginners-guide#methodology
from .base import CLAlgorithm
from .finetuning import Finetuning

__all__ = ["CLAlgorithm", "finetuning"]
class CLAlgorithm(LightningModule):
    """
    The base class of continual learning algorithms, inherited from `LightningModule`.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
    ) -> None:
        """Initialise the CL algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        """
        super().__init__()

        self.backbone: CLBackbone = backbone
        """Store the backbone network."""
        self.heads: HeadsTIL | HeadsCIL = heads
        """Store the output heads."""
        self.optimizer: Optimizer
        """Store the optimizer object (partially initialised) for the backpropagation of task `self.task_id`. Will be equipped with parameters in `configure_optimizers()`."""
        self.criterion = nn.CrossEntropyLoss()
        """The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.task_id: int
        """Task ID counter indicating which task is being processed. Updated via `setup_task_id()` during the task loop."""

        self.sanity_check()

    def sanity_check(self) -> None:
        """Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: if the `output_dim` of the backbone network is not equal to the `input_dim` of the CL heads.
        """
        if self.backbone.output_dim != self.heads.input_dim:
            raise ValueError(
                "The output_dim of backbone network should be equal to the input_dim of CL heads!"
            )

    def setup_task_id(
        self, task_id: int, num_classes_t: int, optimizer: Optimizer
    ) -> None:
        """Set up which task's dataset the CL experiment is on. This must be done before the `forward()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        - **num_classes_t** (`int`): the number of classes in the task.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialised) for task `task_id`.
        """
        self.task_id = task_id
        self.heads.setup_task_id(task_id, num_classes_t)
        self.optimizer = optimizer

    @override
    def forward(self, input: Tensor, task_id: int) -> Tensor:
        """The default forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **task_id** (`int`): the ID of the task the data come from.

        **Returns:**
        - The output logits tensor.
        """
        # extract features with the backbone, then map them to logits with the heads
        feature = self.backbone(input, task_id)
        logits = self.heads(feature, task_id)
        return logits

    def configure_optimizers(self) -> Optimizer:
        """
        Configure optimizer hooks by Lightning.
        See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details.
        """
        # Finish the partially initialised optimizer by specifying the model parameters.
        # The `parameters()` method of this `CLAlgorithm` (inherited from `LightningModule`)
        # returns both backbone and heads parameters.
        return self.optimizer(params=self.parameters())
The base class of continual learning algorithms, inherited from `LightningModule`.
def __init__(self, backbone: CLBackbone, heads: HeadsTIL | HeadsCIL) -> None
Initialise the CL algorithm with the network.
Args:
- backbone (`CLBackbone`): backbone network.
- heads (`HeadsTIL` | `HeadsCIL`): output heads.
backbone: CLBackbone
Store the backbone network.
heads: HeadsTIL | HeadsCIL
Store the output heads.
optimizer: Optimizer
Store the optimizer object (partially initialised) for the backpropagation of task `self.task_id`. Will be equipped with parameters in `configure_optimizers()`.
criterion
The loss function between the output logits and the target labels. Default is cross-entropy loss.
task_id: int
Task ID counter indicating which task is being processed. Updated via `setup_task_id()` during the task loop.
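To make the default `criterion` concrete, the following is a minimal sketch of what cross-entropy computes for a single sample: the negative log-softmax of the logits at the target index. It uses only the standard library; the `cross_entropy` helper and the example values are hypothetical and are not part of CLArena, which uses `nn.CrossEntropyLoss` on tensors.

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """Cross-entropy of one sample: negative log-softmax at the target index.

    Hypothetical helper for illustration only; not a CLArena function.
    """
    shift = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]


logits = [2.0, 0.5, -1.0]

# The loss is small when the target is the class with the largest logit...
loss_if_label_0 = cross_entropy(logits, target=0)
# ...and large when the target is the class with the smallest logit.
loss_if_label_2 = cross_entropy(logits, target=2)

# With identical logits the prediction is uniform, so the loss is log(number of classes).
uniform_loss = cross_entropy([0.0, 0.0, 0.0], target=1)
```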
def sanity_check(self) -> None
Check the sanity of the arguments.
Raises:
- ValueError: if the `output_dim` of the backbone network is not equal to the `input_dim` of the CL heads.
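The check above can be sketched in isolation as follows. `ToyBackbone`, `ToyHeads` and `check_dims` are hypothetical stand-ins that carry only the two dimensions being compared; the real backbone and heads are CLArena network classes.

```python
from dataclasses import dataclass


@dataclass
class ToyBackbone:
    """Hypothetical stand-in for a backbone; only `output_dim` matters here."""

    output_dim: int


@dataclass
class ToyHeads:
    """Hypothetical stand-in for the output heads; only `input_dim` matters here."""

    input_dim: int


def check_dims(backbone: ToyBackbone, heads: ToyHeads) -> None:
    """Mirror the comparison made by `sanity_check()`."""
    if backbone.output_dim != heads.input_dim:
        raise ValueError(
            "The output_dim of backbone network should be equal to the input_dim of CL heads!"
        )


# Matching dimensions: the check passes silently.
check_dims(ToyBackbone(output_dim=64), ToyHeads(input_dim=64))

# Mismatched dimensions: the check raises ValueError.
try:
    check_dims(ToyBackbone(output_dim=64), ToyHeads(input_dim=128))
except ValueError as err:
    mismatch_message = str(err)
```

Because `sanity_check()` is called at the end of `__init__()`, a mismatch of this kind surfaces when the algorithm object is constructed rather than at the first forward pass.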
def setup_task_id(self, task_id: int, num_classes_t: int, optimizer: Optimizer) -> None
Set up which task's dataset the CL experiment is on. This must be done before the `forward()` method is called.
Args:
- task_id (`int`): the target task ID.
- num_classes_t (`int`): the number of classes in the task.
- optimizer (`Optimizer`): the optimizer object (partially initialised) for task `task_id`.
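The ordering requirement (set up the task first, then run the forward pass) can be illustrated with a small standard-library sketch. `ToyHeads` and `ToyAlgorithm` are hypothetical stand-ins, lists replace tensors, and the doubling "backbone" is invented for illustration; only the shape of `setup_task_id()` and `forward()` follows the source above.

```python
class ToyHeads:
    """Hypothetical stand-in for the output heads: one output size per task."""

    def __init__(self) -> None:
        self.num_classes: dict[int, int] = {}

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        # Register the task so that a later forward pass can find its head.
        self.num_classes[task_id] = num_classes_t

    def __call__(self, feature: list[float], task_id: int) -> list[float]:
        # Toy "logits": one value per class of the requested task.
        return [sum(feature)] * self.num_classes[task_id]


class ToyAlgorithm:
    """Hypothetical stand-in mirroring the structure of `CLAlgorithm`."""

    def __init__(self, heads: ToyHeads) -> None:
        self.heads = heads

    def setup_task_id(self, task_id: int, num_classes_t: int, optimizer) -> None:
        self.task_id = task_id
        self.heads.setup_task_id(task_id, num_classes_t)
        self.optimizer = optimizer

    def forward(self, input: list[float], task_id: int) -> list[float]:
        feature = [2.0 * x for x in input]  # toy backbone
        return self.heads(feature, task_id)


algorithm = ToyAlgorithm(ToyHeads())
logits_per_task = {}

# Task loop: each task is set up before any data from it is passed forward.
for task_id, num_classes_t in [(1, 2), (2, 3)]:
    algorithm.setup_task_id(task_id, num_classes_t, optimizer=None)
    logits_per_task[task_id] = algorithm.forward([0.5, 1.5], task_id)
```

In this sketch, calling `forward()` for a task that was never set up fails with a `KeyError`, which is the toy analogue of skipping `setup_task_id()`.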
@override
def forward(self, input: Tensor, task_id: int)
The default forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.
Args:
- input (`Tensor`): the input tensor from data.
- task_id (`int`): the ID of the task the data come from.
Returns:
- The output logits tensor.
def configure_optimizers(self) -> Optimizer
Configure optimizer hooks by Lightning. See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details.
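The "partially initialised" optimizer is a callable whose hyperparameters are already fixed and which only needs the model parameters to become a full optimizer. A minimal sketch of that pattern, assuming the callable is built with `functools.partial` (how CLArena actually constructs it is not stated here), looks like this. `ToySGD` and `ToyAlgorithm` are hypothetical stand-ins, and strings replace real parameter tensors.

```python
from functools import partial


class ToySGD:
    """Hypothetical stand-in for an optimizer class such as `torch.optim.SGD`."""

    def __init__(self, params, lr: float) -> None:
        self.params = list(params)
        self.lr = lr


class ToyAlgorithm:
    """Hypothetical stand-in mirroring `setup_task_id()` and `configure_optimizers()`."""

    def parameters(self):
        # Stands in for the backbone and heads parameters of the real module.
        yield from ("backbone.weight", "heads.weight")

    def setup_task_id(self, optimizer) -> None:
        # Store the partially initialised optimizer for the current task.
        self.optimizer = optimizer

    def configure_optimizers(self) -> ToySGD:
        # Complete the optimizer by supplying the model parameters.
        return self.optimizer(params=self.parameters())


# Hyperparameters are fixed now; the parameters are supplied later.
partial_optimizer = partial(ToySGD, lr=0.01)

algorithm = ToyAlgorithm()
algorithm.setup_task_id(partial_optimizer)
optimizer = algorithm.configure_optimizers()
```

Deferring the parameters this way lets the optimizer be configured before the model exists, and lets a fresh optimizer be created for each task.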