clarena.backbones

Backbone Networks for Continual Learning

This submodule provides the neural network architectures for continual learning that can be used in CLArena.

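Please note that this is API documentation. Refer to the main documentation pages for more information about the backbone networks and how to use and customize them:

  • Configure your backbone network: https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/backbone-network
  • Implement your backbone network: https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/backbone-network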

"""

# Backbone Networks for Continual Learning

This submodule provides the **neural network architectures** for continual learning that can be used in CLArena.

Please note that this is API documentation. Refer to the main documentation pages for more information about the backbone networks and how to use and customize them:

- **Configure your backbone network:** [https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/backbone-network](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiments/backbone-network)
- **Implement your backbone network:** [https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/backbone-network](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-cl-modules/backbone-network)

"""

from .base import CLBackbone
from .mlp import MLP

__all__ = ["CLBackbone", "mlp"]

class CLBackbone(torch.nn.modules.module.Module):
class CLBackbone(nn.Module):
    """The base class for continual learning backbone networks. It inherits from `torch.nn.Module`."""

    def __init__(self, output_dim: int) -> None:
        """Initialise the CL backbone network.

        **Args:**
        - **output_dim** (`int`): The output dimension, which connects to the CL output heads. The `input_dim` of the output heads is expected to equal this `output_dim`.
        """
        super().__init__()

        self.output_dim = output_dim
        """Store the output dimension of the backbone network."""

        self.task_id: int
        """Task ID counter indicating which task is being processed. Updated automatically during the task loop."""

    def setup_task_id(self, task_id: int) -> None:
        """Set which task's dataset the CL experiment is currently on. This must be done before the `forward()` method is called.

        **Args:**
        - **task_id** (`int`): The target task ID.
        """
        self.task_id = task_id

    @override
    def forward(self, input: Tensor, task_id: int | None = None) -> Tensor:
        """The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **task_id** (`int | None`): The ID of the task the data comes from. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so the argument is only a placeholder for API consistency and is never used; the best practice is to omit it and leave the default value.

        **Returns:**
        - The output feature tensor to be passed into heads.
        """

The base class for continual learning backbone networks. It inherits from torch.nn.Module.

CLBackbone(output_dim: int)

Initialise the CL backbone network.

Args:

  • output_dim (int): The output dimension, which connects to the CL output heads. The input_dim of the output heads is expected to equal this output_dim.
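
The dimension contract described above can be illustrated with a minimal sketch. To stay runnable without PyTorch installed, it uses plain-Python stand-ins: `SketchBackbone` and `SketchHead` are hypothetical classes that are not part of CLArena, and their fixed weights are invented for illustration. In CLArena the backbone would subclass `CLBackbone` and operate on tensors.

```python
class SketchBackbone:
    """Plain-Python stand-in for a CLBackbone subclass (hypothetical; no PyTorch needed)."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        self.output_dim = output_dim
        # Hypothetical fixed weights: each output feature is the mean of the input.
        self.weights = [[1.0 / input_dim] * input_dim for _ in range(output_dim)]

    def forward(self, x: list[float]) -> list[float]:
        return [sum(w * v for w, v in zip(row, x)) for row in self.weights]


class SketchHead:
    """Hypothetical output head that consumes features of size `input_dim`."""

    def __init__(self, input_dim: int, num_classes: int) -> None:
        self.input_dim = input_dim
        # Hypothetical fixed weights: class c scales the feature sum by (c + 1).
        self.weights = [[float(c + 1)] * input_dim for c in range(num_classes)]

    def forward(self, feature: list[float]) -> list[float]:
        if len(feature) != self.input_dim:
            raise ValueError(f"expected {self.input_dim} features, got {len(feature)}")
        return [sum(w * f for w, f in zip(row, feature)) for row in self.weights]


backbone = SketchBackbone(input_dim=4, output_dim=3)
# The head's input size is taken from the backbone so the two always agree.
head = SketchHead(input_dim=backbone.output_dim, num_classes=2)
feature = backbone.forward([1.0, 2.0, 3.0, 4.0])
logits = head.forward(feature)
```

Deriving the head's `input_dim` from `backbone.output_dim`, rather than repeating the number, is one way to keep the two sizes from drifting apart.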
output_dim

Store the output dimension of the backbone network.

task_id: int

Task ID counter indicating which task is being processed. Updated automatically during the task loop.

def setup_task_id(self, task_id: int) -> None:

Set which task's dataset the CL experiment is currently on. This must be done before the forward() method is called.

Args:

  • task_id (int): the target task ID.
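
The ordering requirement can be sketched as follows. This is a plain-Python stand-in, not CLArena code: `SketchBackbone` and the behaviour of its `forward()` are hypothetical. It mirrors one detail of the constructor, which only annotates `task_id` (`self.task_id: int`) without assigning it, so the attribute does not exist until `setup_task_id()` has run.

```python
class SketchBackbone:
    """Plain-Python stand-in mirroring the CLBackbone interface (hypothetical)."""

    def __init__(self, output_dim: int) -> None:
        self.output_dim = output_dim
        self.task_id: int  # annotation only: no attribute is created here

    def setup_task_id(self, task_id: int) -> None:
        self.task_id = task_id

    def forward(self, x: list[float]) -> tuple[int, list[float]]:
        # Hypothetical behaviour: tag the (truncated) input with the current task ID.
        return self.task_id, x[: self.output_dim]


backbone = SketchBackbone(output_dim=2)
has_id_before_setup = hasattr(backbone, "task_id")

seen_task_ids = []
for task_id in [1, 2, 3]:
    backbone.setup_task_id(task_id)  # must precede forward() for this task
    current_id, _ = backbone.forward([0.1, 0.2, 0.3])
    seen_task_ids.append(current_id)
```

Calling `forward()` on this stand-in before `setup_task_id()` would raise `AttributeError`, which is why the task loop sets the ID first.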
@override
def forward(self, input: torch.Tensor, task_id: int | None = None) -> torch.Tensor:

The forward pass for data from task task_id. In some backbones, the forward pass might be different for different tasks.

Args:

  • input (Tensor): The input tensor from data.
  • task_id (int | None): The ID of the task the data comes from. In TIL (task-incremental learning), the task IDs of test data are provided, so this argument can be used. In CIL (class-incremental learning), they are not provided, so the argument is only a placeholder for API consistency and is never used; the best practice is to omit it and leave the default value.

Returns:

  • The output feature tensor to be passed into heads.
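
As a rough illustration of a task-dependent forward pass, the sketch below applies a per-task mask when `task_id` is given and returns unmasked features when it is omitted. It is a plain-Python stand-in, not CLArena code: `MaskedSketchBackbone`, its `masks` argument, and the truncation used in place of real layers are all hypothetical.

```python
from __future__ import annotations


class MaskedSketchBackbone:
    """Plain-Python stand-in for a backbone whose forward pass depends on the task (hypothetical)."""

    def __init__(self, output_dim: int, masks: dict[int, list[float]]) -> None:
        self.output_dim = output_dim
        self.masks = masks  # hypothetical: one feature mask per task ID

    def forward(self, x: list[float], task_id: int | None = None) -> list[float]:
        # Truncation stands in for the real layers of a backbone.
        feature = x[: self.output_dim]
        if task_id is None:
            # CIL-style call: no task identity is available, so no mask is applied.
            return feature
        # TIL-style call: select the mask belonging to the given task.
        return [m * f for m, f in zip(self.masks[task_id], feature)]


backbone = MaskedSketchBackbone(output_dim=2, masks={1: [1.0, 0.0], 2: [0.0, 1.0]})
x = [3.0, 5.0, 7.0]

til_task_1 = backbone.forward(x, task_id=1)  # task ID known at test time
til_task_2 = backbone.forward(x, task_id=2)
cil_output = backbone.forward(x)             # task ID left at its default
```

Keeping `task_id` optional lets the same call signature serve both settings, which matches the placeholder role described for CIL above.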