Implement Custom CL Algorithm
This section guides you through implementing custom continual learning algorithms for use in CLArena.
Continual learning algorithms define the process of learning a sequence of tasks.
Base Classes
In CLArena, continual learning algorithms are implemented as subclasses of the base classes in `clarena/cl_algorithms/base.py`. The base classes inherit from the Lightning module, with additional features for continual learning:
`clarena.cl_algorithms.CLAlgorithm`
: the base class for all continual learning algorithms.

`clarena.cl_algorithms.UnlearnableCLAlgorithm`
: the base class for unlearnable continual learning algorithms.
Implement CL Algorithm
To implement CL algorithms:
- Inherit `CLAlgorithm`.
- Put your algorithm's hyperparameters in `save_hyperparameters()` in `__init__()`. This enables Lightning loggers to manage them automatically.
- Implement `training_step()`, `validation_step()`, `test_step()` to define the training, validation and test steps respectively (see the sketch after this list). Note that these steps are for the current task indicated by `self.task_id`, which is updated automatically at the beginning of each task.
- Other hooks like `on_train_start()`, `on_validation_start()`, `on_test_start()` are also free to customize if needed.
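A minimal sketch of these steps is shown below. It assumes standard Lightning semantics; the constructor arguments, the `self.network` attribute, and the loss function are illustrative placeholders rather than CLArena's actual interface, so check the implemented algorithms for the real signatures.

```python
from torch import nn

from clarena.cl_algorithms import CLAlgorithm


class YourAlgorithm(CLAlgorithm):
    """A hypothetical CL algorithm; attribute names are illustrative."""

    def __init__(self, your_hyperparameter: float, **kwargs):
        super().__init__(**kwargs)  # forward the base class's own arguments (assumed)
        # register hyperparameters so Lightning loggers manage them automatically
        self.save_hyperparameters("your_hyperparameter")
        self.criterion = nn.CrossEntropyLoss()  # assumed classification loss

    def training_step(self, batch, batch_idx):
        # the batch comes from the current task, indicated by self.task_id
        x, y = batch
        logits = self.network(x)  # `self.network` is an assumed attribute name
        loss = self.criterion(logits, y)
        self.log(f"task_{self.task_id}/train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log(f"task_{self.task_id}/val_loss", self.criterion(self.network(x), y))

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log(f"task_{self.task_id}/test_loss", self.criterion(self.network(x), y))
```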
`CLAlgorithm` works the same as a Lightning module, so all the other hooks are free to customize. Please refer to the Lightning module documentation for details about the hooks.
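For instance, a hedged sketch of customizing one of these hooks (the print statement is purely illustrative):

```python
from clarena.cl_algorithms import CLAlgorithm


class YourAlgorithm(CLAlgorithm):
    def on_train_start(self):
        # runs once before training starts on each task
        print(f"Starting training on task {self.task_id}")
```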
Note that `configure_optimizers()` is already implemented in `CLAlgorithm` to manage the optimizer and learning rate scheduler for each task automatically. You don't have to implement it unless you want to override the default behaviour.
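If you do want to override it, the override follows standard Lightning semantics. A minimal sketch (the optimizer choice and settings here are illustrative, not CLArena's defaults):

```python
import torch

from clarena.cl_algorithms import CLAlgorithm


class YourAlgorithm(CLAlgorithm):
    def configure_optimizers(self):
        # replace the automatic per-task optimizer management with a fixed SGD optimizer
        return torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9)
```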
Implement Your Regularizers
Some CL algorithms require regularizers to manage the interactions between tasks. You can implement your own regularizers if your CL algorithm needs them. A regularizer is a `torch.nn.Module` whose `forward()` method defines the calculation of the regularization loss. It is called in the training step, similarly to the network, and the resulting regularization loss is added to the classification loss to form the total loss.

A best practice for the regularization factor, a hyperparameter multiplied with the regularization term, is to store it as an attribute of the class and apply it in `forward()`:
```python
from torch import nn


class YourReg(nn.Module):
    def __init__(self, factor: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.factor = factor

    def forward(self, *args, **kwargs):
        reg = ...  # compute the regularization term here
        return reg * self.factor
```
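As a concrete illustration of this pattern, below is a hedged sketch of a simple L2-style regularizer that penalizes parameter drift from reference values. It is an illustrative example, not one of CLArena's implemented regularizers:

```python
import torch
from torch import nn


class L2Reg(nn.Module):
    """Illustrative regularizer: L2 penalty pulling parameters toward references."""

    def __init__(self, factor: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.factor = factor

    def forward(self, params, ref_params):
        # sum of squared distances between current and reference parameters
        reg = sum((p - q).pow(2).sum() for p, q in zip(params, ref_params))
        return reg * self.factor


# usage sketch: the tensors stand in for model and reference parameters
reg = L2Reg(factor=0.1)
params = [torch.randn(3, 3, requires_grad=True)]
ref_params = [torch.zeros(3, 3)]
loss_cls = torch.tensor(1.0)  # stands in for the classification loss
loss = loss_cls + reg(params, ref_params)  # total loss = classification + regularization
```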
For more details, please refer to the API Reference and source code. You may take the CL algorithms implemented in CLArena as examples. Feel free to contribute by submitting pull requests on GitHub!