Implement Your CL Algorithm
Continual learning (CL) algorithms are implemented as classes derived from the base classes this package provides in clarena/cl_algorithms/base.py. The base classes build on the Lightning module and add features for continual learning:
CLAlgorithm: the general base class for all CL algorithms; it incorporates the mechanism for managing sequential tasks.
Required Arguments
CLAlgorithm requires 2 arguments, backbone and heads, which are the backbone network and the output heads respectively. Together they form the forward and backward pass of the CL algorithm. Both arguments are passed to CLAlgorithm in our source code, so you don't have to specify them in your config files.
Implementing Instructions
CLAlgorithm works the same way as a Lightning module. Except for configure_optimizers(), which you don't have to worry about, all other hooks are free to customize. Remember that the training, validation and test steps here apply to the current task, indicated by self.task_id, which is updated automatically.
We also provide a default forward() method for convenience. You can override it to implement your own forward pass, but doing so is optional.
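To make the role of self.task_id concrete, here is a minimal sketch in plain PyTorch. It uses neither clarena nor Lightning: ToyCLAlgorithm, its heads dictionary keyed by task ID, and the manual update of task_id are hypothetical stand-ins assumed for illustration, not the package's API. In a real subclass of CLAlgorithm, the same logic would sit in the training_step() hook.

```python
import torch
from torch import nn


class ToyCLAlgorithm(nn.Module):
    """Hypothetical stand-in for a CL algorithm: one shared backbone, one head per task."""

    def __init__(self, backbone: nn.Module, heads: nn.ModuleDict):
        super().__init__()
        self.backbone = backbone
        self.heads = heads
        self.task_id = 1  # assumption: the framework updates this between tasks
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        # Shared features are routed through the head that belongs to the given task.
        feature = self.backbone(x)
        return self.heads[str(task_id)](feature)

    def training_step(self, batch) -> torch.Tensor:
        # The step only handles data of the current task, selected by self.task_id.
        x, y = batch
        logits = self.forward(x, task_id=self.task_id)
        return self.criterion(logits, y)


torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
heads = nn.ModuleDict({"1": nn.Linear(16, 3), "2": nn.Linear(16, 5)})
model = ToyCLAlgorithm(backbone, heads)

batch = (torch.randn(4, 8), torch.tensor([0, 1, 2, 0]))
loss_task1 = model.training_step(batch)  # scalar loss for task 1

model.task_id = 2  # moving to the next task switches the head in use
logits_task2 = model(batch[0], task_id=model.task_id)
```

Because every step reads self.task_id, the same hook code serves all tasks in the sequence. Only the selected head, and therefore the output dimension, changes.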
Implement Your Regularisers
Some CL algorithms require regularisers to manage the interactions between tasks. You can implement your own regularisers if your CL algorithm needs them. A regulariser is a torch.nn.Module whose forward() method defines how the loss is calculated. In the training step it is called in the same way as the network to calculate the regularisation loss, which is then added to the classification loss to form the total loss.
The best practice for the regularisation factor, a hyperparameter that multiplies the regularisation term, is to store it as an attribute of the class:
    from torch import nn


    class YourReg(nn.Module):
        def __init__(self, factor: float, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.factor = factor  # regularisation factor (hyperparameter)

        def forward(self, *args, **kwargs):
            reg = ...  # compute the unscaled regularisation term here
            return reg * self.factor
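As one concrete instance of this template, the sketch below penalises the squared distance between the current parameters and a snapshot taken after the previous task, then adds the result to the classification loss. The class name ParameterChangeReg, its forward() arguments and the toy linear model are assumptions made for illustration; they are not taken from the package.

```python
import copy

import torch
from torch import nn


class ParameterChangeReg(nn.Module):
    """Illustrative regulariser: squared L2 distance between two networks' parameters."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, target_model: nn.Module, ref_model: nn.Module) -> torch.Tensor:
        # Sum squared differences over all parameter pairs; the reference is detached
        # so gradients only flow into the model being trained.
        reg = sum(
            ((p - p_ref.detach()) ** 2).sum()
            for p, p_ref in zip(target_model.parameters(), ref_model.parameters())
        )
        return reg * self.factor


torch.manual_seed(0)
model = nn.Linear(4, 2)
previous_model = copy.deepcopy(model)  # snapshot taken after the previous task

with torch.no_grad():
    model.weight.add_(1.0)  # pretend training shifted every weight by 1

regulariser = ParameterChangeReg(factor=0.5)
x, y = torch.randn(3, 4), torch.tensor([0, 1, 0])

loss_cls = nn.functional.cross_entropy(model(x), y)
loss_reg = regulariser(model, previous_model)  # 8 weights shifted by 1: 8 * 0.5 = 4.0
loss_total = loss_cls + loss_reg
```

Keeping the factor inside the regulariser means the training step can simply add loss_reg to loss_cls without handling the scaling itself.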
Please see the API documentation for detailed information about implementation. You can also use the CL algorithms already implemented in the package source code under clarena/cl_algorithms/ as examples, such as Finetuning!