
Implement Your CL Algorithm

Continual learning algorithms are implemented as classes based on the base classes provided by this package in clarena/cl_algorithms/base.py. The base classes are built on the Lightning module with added features for continual learning:

  • CLAlgorithm: the general base class for all CL algorithms, which incorporates mechanisms for managing sequential tasks.

Required Arguments

CLAlgorithm requires 2 arguments: backbone and heads, which are the backbone network and the output heads respectively; together they form the forward and backward pass of the CL algorithm. These 2 arguments are passed to CLAlgorithm in our source code, so you don’t have to specify them in your config files.
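For illustration, a minimal subclass might simply receive these two arguments and hand them to the base class. This is a hedged sketch: the subclass name is hypothetical, and the import path is assumed from the clarena/cl_algorithms/base.py location mentioned above.

from clarena.cl_algorithms.base import CLAlgorithm  # import path assumed from the package layout above


class YourCLAlgorithm(CLAlgorithm):

    def __init__(self, backbone, heads, **kwargs):
        # backbone and heads are supplied by the framework's source code,
        # not by your config files; just pass them through to the base class.
        super().__init__(backbone=backbone, heads=heads, **kwargs)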

Implementing Instructions

CLAlgorithm works the same way as a Lightning module. Apart from configure_optimizers(), which you don’t have to worry about, all the other hooks are free to customise; just remember that the training, validation and test steps here operate on the current task, indicated by the automatically updated self.task_id.

We also provide a default forward() method for convenience. You can override it to implement your own forward pass, or simply leave it as is; overriding is not necessary.
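To make this concrete, here is a hedged sketch of a training_step() for the current task. Only self.task_id comes from the text above; the rest is illustrative, assuming the backbone returns features and that heads can be called with a task index.

import torch.nn.functional as F


class YourCLAlgorithm(CLAlgorithm):

    def training_step(self, batch, batch_idx):
        x, y = batch
        # self.task_id is updated automatically to indicate the current task.
        features = self.backbone(x)                   # assumed: backbone returns features
        logits = self.heads(features, self.task_id)   # assumed: heads takes a task index
        loss = F.cross_entropy(logits, y)             # classification loss for the current task
        return loss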

Implement Your Regularisers

Some CL algorithms require regularisers to manage the interactions between tasks. You can implement your own regularisers if your CL algorithm needs them. A regulariser is a torch.nn.Module whose forward() method defines the calculation of the loss. It is called in the training step, much like the network itself, to calculate the regularisation loss, which is then added to the classification loss to form the total loss.

Tip

A best practice for the regularisation factor, a hyperparameter multiplied with the regularisation term, is to store it as an attribute of the class:

from torch import nn

class YourReg(nn.Module):
    """A regulariser whose forward() returns the scaled regularisation loss."""

    def __init__(self, factor: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.factor = factor  # regularisation factor (hyperparameter)

    def forward(self, *args, **kwargs):
        reg = ...  # compute the regularisation term here
        return reg * self.factor  # scale by the regularisation factor
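In the training step, the regulariser is then called like the network to produce the regularisation loss, which is added to the classification loss. A minimal sketch of how that might look, extending the training step sketched earlier and assuming self.reg = YourReg(factor=...) was set in your algorithm's __init__ and that its forward() needs no extra arguments here:

import torch.nn.functional as F

# Inside your CLAlgorithm subclass:
def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self.heads(self.backbone(x), self.task_id)
    cls_loss = F.cross_entropy(logits, y)  # classification loss for the current task
    reg_loss = self.reg()                  # regularisation loss, already scaled by its factor
    return cls_loss + reg_loss             # total loss used for the backward pass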

Please see the API reference for detailed information about implementation, and remember that you can always take the CL algorithms implemented in the package source code under clarena/cl_algorithms/ as examples, like Finetuning!

