Implement Custom Backbone Network
This section guides you through implementing custom backbone networks for use in CLArena.
Backbones for multi-task learning (MTL) and single-task learning (STL) are no different from ordinary neural networks defined as PyTorch nn.Module subclasses. In continual learning (CL), the backbone is typically shared among tasks. In some CL approaches (particularly architecture-based ones), however, the backbone is dynamic: it can expand, incorporate additional mechanisms such as masks, or even assign a different network to each task.
Base Classes
In CLArena, backbone networks are implemented as subclasses of the base classes defined in clarena/backbones/base.py. The base classes inherit from PyTorch nn.Module and add features for different paradigms:
- clarena.backbones.Backbone: The base class for all backbone networks. Multi-task and single-task learning can use this class directly.
- clarena.backbones.CLBackbone: The base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks.
- clarena.backbones.HATMaskBackbone: The base class for backbones used in the HAT (Hard Attention to the Task) CL algorithm.
- clarena.backbones.WSNMaskBackbone: The base class for backbones used in the WSN (Winning Subnetworks) CL algorithm.
Implement MTL & STL Backbones
You can implement a backbone just as you would a PyTorch nn.Module:
- Inherit from Backbone;
- Define the model layers in __init__();
- Define the forward() method. If the behaviour differs between stages, use the stage argument to distinguish them.
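The three steps above can be sketched as follows. This is a minimal illustration, not CLArena's own code: the Backbone class here is a local stand-in for clarena.backbones.Backbone so that the snippet runs on its own, and its output_dim argument, the forward() return value, and the stage values "train" and "test" are all assumptions. Check the API Reference for the real signatures.

```python
import torch
from torch import nn


class Backbone(nn.Module):
    """Local stand-in for clarena.backbones.Backbone (assumed interface)."""

    def __init__(self, output_dim: int) -> None:
        super().__init__()
        self.output_dim = output_dim  # assumed attribute: size of the output feature


class SmallMLP(Backbone):
    """Hypothetical two-layer MLP backbone; dropout is applied only in training."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__(output_dim=output_dim)
        # Step 2: define the model layers in __init__().
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x: torch.Tensor, stage: str = "train") -> torch.Tensor:
        # Step 3: define forward(); flatten everything except the batch dimension.
        h = torch.relu(self.fc1(x.flatten(start_dim=1)))
        # Use the stage argument to switch behaviour between stages.
        if stage == "train":
            h = self.dropout(h)
        return torch.relu(self.fc2(h))


backbone = SmallMLP(input_dim=28 * 28, hidden_dim=64, output_dim=32)
x = torch.randn(4, 1, 28, 28)
features = backbone(x, stage="test")
print(features.shape)  # torch.Size([4, 32])
```

Because the dropout branch is skipped when stage is "test", repeated test-stage calls on the same input return identical features.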
Implement CL Backbones
For continual learning backbones, we recommend first implementing the network as a Backbone subclass and then combining it with CLBackbone through multiple inheritance. For example, to implement CLMLP, first implement MLP as a subclass of Backbone, then define class CLMLP(CLBackbone, MLP).
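The multiple-inheritance pattern can be sketched as below. Backbone and CLBackbone are local stand-ins written for this example, not the real CLArena classes. The assumption that CLBackbone subclasses Backbone, the keyword forwarding in its constructor, and the setup_task_id() method are all hypothetical, so treat this as an illustration of the class layout only.

```python
import torch
from torch import nn


class Backbone(nn.Module):
    """Local stand-in for clarena.backbones.Backbone (assumed interface)."""

    def __init__(self, output_dim: int) -> None:
        super().__init__()
        self.output_dim = output_dim


class CLBackbone(Backbone):
    """Local stand-in for clarena.backbones.CLBackbone (assumed interface)."""

    def __init__(self, output_dim: int, **kwargs) -> None:
        # Pass remaining keyword arguments to the next class in the MRO.
        super().__init__(output_dim=output_dim, **kwargs)
        self.task_id = None

    def setup_task_id(self, task_id: int) -> None:  # hypothetical method
        self.task_id = task_id


class MLP(Backbone):
    """Plain backbone: usable on its own for MTL and STL."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__(output_dim=output_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor, stage: str = "train") -> torch.Tensor:
        h = torch.relu(self.fc1(x.flatten(start_dim=1)))
        return torch.relu(self.fc2(h))


class CLMLP(CLBackbone, MLP):
    """CL version: task management from CLBackbone, layers and forward() from MLP."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__(
            input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
        )


backbone = CLMLP(input_dim=8, hidden_dim=16, output_dim=4)
backbone.setup_task_id(1)
print([cls.__name__ for cls in CLMLP.__mro__[:4]])
# ['CLMLP', 'CLBackbone', 'MLP', 'Backbone']
print(backbone(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```

Listing CLBackbone first places it ahead of MLP in the method resolution order, so its task-management logic wraps the plain network while the layers and forward() are reused unchanged.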
For CL algorithms that require different backbones for different tasks, use self.task_id in CLBackbone as the task index to distinguish between tasks.
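One way to use the task index is to keep a separate sub-network per task and select it in forward(). The sketch below assumes a setup_task_id() hook that is called before each task starts. That hook, the stand-in CLBackbone class, and the PerTaskLinear name are hypothetical and serve only to keep the example self-contained.

```python
import torch
from torch import nn


class CLBackbone(nn.Module):
    """Local stand-in for clarena.backbones.CLBackbone (assumed interface)."""

    def __init__(self, output_dim: int) -> None:
        super().__init__()
        self.output_dim = output_dim
        self.task_id = None  # index of the current task

    def setup_task_id(self, task_id: int) -> None:  # hypothetical hook
        self.task_id = task_id


class PerTaskLinear(CLBackbone):
    """Hypothetical backbone that keeps an independent linear layer per task."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__(output_dim=output_dim)
        self.input_dim = input_dim
        self.nets = nn.ModuleDict()  # ModuleDict keys must be strings

    def setup_task_id(self, task_id: int) -> None:
        super().setup_task_id(task_id)
        key = str(task_id)
        # Create a fresh network the first time a task is seen.
        if key not in self.nets:
            self.nets[key] = nn.Linear(self.input_dim, self.output_dim)

    def forward(self, x: torch.Tensor, stage: str = "train") -> torch.Tensor:
        # self.task_id selects which task's network processes the input.
        return self.nets[str(self.task_id)](x)


backbone = PerTaskLinear(input_dim=8, output_dim=4)
x = torch.randn(2, 8)
backbone.setup_task_id(1)
out_task1 = backbone(x)
backbone.setup_task_id(2)
out_task2 = backbone(x)
print(len(backbone.nets), out_task1.shape)  # 2 torch.Size([2, 4])
```

Storing the sub-networks in nn.ModuleDict rather than a plain dict registers their parameters with the parent module, so optimizers and state_dict() can see them.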
For more details, please refer to the API Reference and the source code. The backbones already implemented in CLArena can serve as examples. Feel free to contribute by submitting pull requests on GitHub!