clarena.backbones

Backbone Networks

This submodule provides the backbone neural network architectures for all paradigms in CLArena.

Here are the base classes for backbone networks, which inherit from PyTorch `nn.Module`:

- `Backbone`: the base class for all backbone networks. Multi-task and single-task learning can use this class directly.
- `CLBackbone`: the base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks.
    - `HATMaskBackbone`: the base class for backbones used in the [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) CL algorithm.
        - `AmnesiacHATBackbone`: the base class for backbones used in the AmnesiacHAT CL algorithm.
    - `WSNMaskBackbone`: the base class for backbones used in the [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) CL algorithm.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement backbone networks:

- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/backbone-network)
- [**Implement Custom Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/backbone-network)

```python
r"""
# Backbone Networks

This submodule provides the **backbone neural network architectures** for all paradigms in CLArena.

Here are the base classes for backbone networks, which inherit from PyTorch `nn.Module`:

- `Backbone`: the base class for all backbone networks. Multi-task and single-task learning can use this class directly.
- `CLBackbone`: the base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks.
    - `HATMaskBackbone`: the base class for backbones used in the [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) CL algorithm.
        - `AmnesiacHATBackbone`: the base class for backbones used in the AmnesiacHAT CL algorithm.
    - `WSNMaskBackbone`: the base class for backbones used in the [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) CL algorithm.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement backbone networks:

- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/backbone-network)
- [**Implement Custom Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/backbone-network)
"""

from .base import (
    Backbone,
    CLBackbone,
    HATMaskBackbone,
    AmnesiacHATBackbone,
    WSNMaskBackbone,
)
from .mlp import MLP, CLMLP
from .hat_mask_mlp import HATMaskMLP
from .amnesiac_hat_mlp import AmnesiacHATMLP
from .wsn_mask_mlp import WSNMaskMLP
from .resnet import (
    ResNet18,
    ResNet34,
    ResNet50,
    ResNet101,
    ResNet152,
    CLResNet18,
    CLResNet34,
    CLResNet50,
    CLResNet101,
    CLResNet152,
)
from .hat_mask_resnet import (
    HATMaskResNet18,
    HATMaskResNet34,
    HATMaskResNet50,
    HATMaskResNet101,
    HATMaskResNet152,
)
from .amnesiac_hat_resnet import (
    AmnesiacHATResNet18,
    AmnesiacHATResNet34,
    AmnesiacHATResNet50,
    AmnesiacHATResNet101,
    AmnesiacHATResNet152,
)


__all__ = [
    "Backbone",
    "CLBackbone",
    "HATMaskBackbone",
    "AmnesiacHATBackbone",
    "WSNMaskBackbone",
    "mlp",
    "hat_mask_mlp",
    "amnesiac_hat_mlp",
    "wsn_mask_mlp",
    "resnet",
    "hat_mask_resnet",
    "amnesiac_hat_resnet",
]
```
class Backbone(torch.nn.modules.module.Module):

```python
class Backbone(nn.Module):
    r"""The base class for backbone networks."""

    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__()

        self.output_dim: int | None = output_dim
        r"""The output dimension of the backbone network."""

        self.weighted_layer_names: list[str] = []
        r"""The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers; such layers are the main part of neural networks. **It must be provided in subclasses.**

        The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the `nn.Module` internal naming mechanism with `.` replaced by `/`. For example:
        - If a layer is assigned to `self.conv1`, the name becomes `conv1`.
        - If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc.
        - If a hierarchical structure is used, for example, an `nn.Module` assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it would be `block.conv1` according to `nn.Module`'s rules, but we use '/' instead of '.' to avoid errors when using '.' in keys of a `ModuleDict`.
        """

    def get_layer_by_name(self, layer_name: str | None) -> nn.Module | None:
        r"""Get the layer by its name.

        **Args:**
        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced by `/`. If `None`, return `None`.

        **Returns:**
        - **layer** (`nn.Module` | `None`): The layer. If `layer_name` is `None` or no layer matches, return `None`.
        """
        if layer_name is None:
            return None

        for name, layer in self.named_modules():
            if name == layer_name.replace("/", "."):
                return layer
        return None

    def preceding_layer_name(self, layer_name: str | None) -> str | None:
        r"""Get the name of the preceding layer of the given layer from the stored `weighted_layer_names`.

        **Args:**
        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced by `/`. If `None`, return `None`.

        **Returns:**
        - **preceding_layer_name** (`str` | `None`): The name of the preceding layer. If the given layer is the first layer, return `None`.

        **Raises:**
        - **ValueError**: If `layer_name` is not in `weighted_layer_names`.
        """
        if layer_name is None:
            return None

        if layer_name not in self.weighted_layer_names:
            raise ValueError(
                f"The layer name {layer_name} doesn't exist in weighted layer names."
            )

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == 0:
            return None
        return self.weighted_layer_names[weighted_layer_idx - 1]

    def next_layer_name(self, layer_name: str) -> str | None:
        r"""Get the name of the next layer of the given layer from the stored `weighted_layer_names`.

        **Args:**
        - **layer_name** (`str`): The name of the layer.

        **Returns:**
        - **next_layer_name** (`str` | `None`): The name of the next layer. If the given layer is the last layer of the backbone, return `None`.

        **Raises:**
        - **ValueError**: If `layer_name` is not in `weighted_layer_names`.
        """
        if layer_name not in self.weighted_layer_names:
            raise ValueError(f"The layer name {layer_name} doesn't exist.")

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == len(self.weighted_layer_names) - 1:
            return None
        return self.weighted_layer_names[weighted_layer_idx + 1]

    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used by certain algorithms that need the hidden features for various purposes.
        """
```

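To make the `forward()` contract concrete, here is a dependency-free mimic: plain Python lists stand in for tensors so the example runs without `torch`, and `TinyBackbone` with layers `fc1`/`fc2` is entirely hypothetical. The point is the return shape: the output feature plus a dict of per-layer activations keyed by the weighted layer names:

```python
# Dependency-free sketch of the forward() contract (hypothetical layer names;
# lists stand in for tensors so the example runs without torch).

class TinyBackbone:
    weighted_layer_names = ["fc1", "fc2"]
    output_dim = 2

    def forward(self, input, stage):
        activations = {}
        hidden = [x * 2 for x in input]           # stand-in for fc1 + activation
        activations["fc1"] = hidden
        output_feature = [x + 1 for x in hidden]  # stand-in for fc2 + activation
        activations["fc2"] = output_feature
        return output_feature, activations

feature, acts = TinyBackbone().forward([1.0, 2.0], stage="train")
print(feature)              # [3.0, 5.0]
print(sorted(acts.keys()))  # ['fc1', 'fc2']
```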
class CLBackbone(clarena.backbones.Backbone):

```python
class CLBackbone(Backbone):
    r"""The base class of continual learning backbone networks."""

    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Updated automatically during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up task `task_id`. This must be done before the `forward()` method is called."""
        self.task_id = task_id
        self.processed_task_ids.append(task_id)

    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
        task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used by continual learning algorithms that need hidden features for various purposes.
        """
```
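The task-ID bookkeeping above is simple enough to sketch standalone. This is a minimal pure-Python mimic of the control flow (the class name `TaskIDControl` is invented for illustration; it is not part of the library):

```python
# Minimal mimic of CLBackbone's task-ID control (illustration only).

class TaskIDControl:
    def __init__(self) -> None:
        self.task_id: int | None = None        # current task; set before forward()
        self.processed_task_ids: list[int] = []

    def setup_task_id(self, task_id: int) -> None:
        """Must be called before forward(); records the task as processed."""
        self.task_id = task_id
        self.processed_task_ids.append(task_id)

ctrl = TaskIDControl()
for t in [1, 2, 3]:  # task IDs are 1-based, per the docstring above
    ctrl.setup_task_id(t)
print(ctrl.task_id)             # 3
print(ctrl.processed_task_ids)  # [1, 2, 3]
```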

class HATMaskBackbone(clarena.backbones.CLBackbone):
176class HATMaskBackbone(CLBackbone):
177    r"""The backbone network for HAT-based algorithms with learnable hard attention masks.
178
179    HAT-based algorithms include:
180
181    - [**HAT (Hard Attention to the Task, 2018)**](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
182    - [**AdaHAT (Adaptive Hard Attention to the Task, 2024)**](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
183    - **FG-AdaHAT** is an architecture-based continual learning approach that improves HAT by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
184    """
185
186    def __init__(self, output_dim: int | None, gate: str, **kwargs) -> None:
187        r"""
188        **Args:**
189        - **output_dim** (`int`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
190        - **gate** (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
191            - `sigmoid`: the sigmoid function.
192        - **kwargs**: Reserved for multiple inheritance.
193        """
194        super().__init__(output_dim=output_dim, **kwargs)
195
196        self.gate: str = gate
197        r"""The type of gate function."""
198        self.gate_fn: Callable
199        r"""The gate function mapping the real value task embeddings into attention masks."""
200
201        if gate == "sigmoid":
202            self.gate_fn = nn.Sigmoid()
203
204        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
205        r"""The task embedding for the current task `self.task_id`. Keys are layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).
206        
207        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
208        
209        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
210        
211        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
212        """
213
214        self.masks: dict[int, dict[str, Tensor]] = {}
215        r"""The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ). """
216
217        HATMaskBackbone.sanity_check(self)
218
219    def initialize_task_embedding(self, mode: str) -> None:
220        r"""Initialize the task embedding for the current task `self.task_id`.
221
222        **Args:**
223        - **mode** (`str`): The initialization mode for task embeddings; one of:
224            1. 'N01' (default): standard normal distribution $N(0, 1)$.
225            2. 'U-11': uniform distribution $U(-1, 1)$.
226            3. 'U01': uniform distribution $U(0, 1)$.
227            4. 'U-10': uniform distribution $U(-1, 0)$.
228            5. 'last': inherit task embeddings from the last task.
229        """
230        for te in self.task_embedding_t.values():
231            if mode == "N01":
232                nn.init.normal_(te.weight, 0, 1)
233            elif mode == "U-11":
234                nn.init.uniform_(te.weight, -1, 1)
235            elif mode == "U01":
236                nn.init.uniform_(te.weight, 0, 1)
237            elif mode == "U-10":
238                nn.init.uniform_(te.weight, -1, 0)
239            elif mode == "last":
240                pass
241
242    def sanity_check(self) -> None:
243        r"""Sanity check."""
244
245        if self.gate not in ["sigmoid"]:
246            raise ValueError("The gate should be one of: 'sigmoid'.")
247
248    def get_mask(
249        self,
250        stage: str,
251        s_max: float | None = None,
252        batch_idx: int | None = None,
253        num_batches: int | None = None,
254        test_task_id: int | None = None,
255    ) -> dict[str, Tensor]:
256        r"""Get the hard attention mask used in the `forward()` method for different stages.
257
258        **Args:**
259        - **stage** (`str`): The stage when applying the conversion; one of:
260            1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
261            2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by `s_max`, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
262            3. 'test': testing stage. Apply the test mask directly from the stored masks using `test_task_id`.
263            4. 'unlearning_test': unlearning testing stage. The mask is set to all 1s for unlearning testing.
264        - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
265        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
266        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
267        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
268
269        **Returns:**
270        - **mask** (`dict[str, Tensor]`): The hard attention (with values 0 or 1) mask. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
271        """
272
273        # sanity check
274        if stage == "train" and (
275            s_max is None or batch_idx is None or num_batches is None
276        ):
277            raise ValueError(
278                "`s_max`, `batch_idx` and `num_batches` must be provided at the training stage instead of the default `None`."
279            )
280        if stage == "validation" and (s_max is None):
281            raise ValueError(
282                "`s_max` must be provided at the validation stage instead of the default `None`."
283            )
284        if stage == "test" and (test_task_id is None):
285            raise ValueError(
286                "`test_task_id` must be provided at the testing stage instead of the default `None`."
287            )
288
289        mask = {}
290        if stage == "train":
291            for layer_name in self.weighted_layer_names:
292                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
293                    num_batches - 1
294                )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
295                mask[layer_name] = self.gate_fn(
296                    self.task_embedding_t[layer_name].weight * anneal_scalar
297                ).squeeze()
298        elif stage == "validation":
299            for layer_name in self.weighted_layer_names:
300                mask[layer_name] = self.gate_fn(
301                    self.task_embedding_t[layer_name].weight * s_max
302                ).squeeze()
303        elif stage == "test":
304            mask = self.masks[test_task_id]
305            for layer_name, layer_mask in mask.items():
306                layer = self.get_layer_by_name(layer_name)
307                target_device = layer.weight.device
308                if layer_mask.device != target_device:
309                    mask[layer_name] = layer_mask.to(target_device)
310        elif stage == "unlearning_test":
311            for layer_name in self.weighted_layer_names:
312                layer = self.get_layer_by_name(layer_name)
313                mask[layer_name] = torch.ones(
314                    layer.weight.size(0), device=layer.weight.device
315                )
316
317        return mask
318
319    def te_to_binary_mask(self) -> dict[str, Tensor]:
320        r"""Convert the current task embedding into a binary mask.
321
322        This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.
323
324        **Returns:**
325        - **mask_t** (`dict[str, Tensor]`): The binary mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
326        """
327        # get the mask for the current task
328        mask_t = {
329            layer_name: (self.task_embedding_t[layer_name].weight > 0)
330            .float()
331            .squeeze()
332            .detach()
333            for layer_name in self.weighted_layer_names
334        }
335
336        return mask_t
337
338    def combine_masks(
339        self, masks: list[dict[str, Tensor]], mode: str
340    ) -> dict[str, Tensor]:
341        r"""Combine multiple masks by taking their element-wise minimum (for intersection) / maximum (for union).
342
343        **Args:**
344        - **masks** (`list[dict[str, Tensor]]`): A list of masks. Each mask is a dict where keys are layer names and values are mask tensors.
345        - **mode** (`str`): The combination mode; one of:
346            - 'intersection': take the element-wise minimum of the masks (for intersection).
347            - 'union': take the element-wise maximum of the masks (for union).
348
349        **Returns:**
350        - **combined_mask** (`dict[str, Tensor]`): The combined mask.
351        """
352
353        combined_mask = {}
354        for layer_name in masks[0].keys():
355            layer_mask_tensors = torch.stack(
356                [mask[layer_name] for mask in masks], dim=0
357            )
358            if mode == "intersection":
359                combined_mask[layer_name] = torch.min(layer_mask_tensors, dim=0).values
360            elif mode == "union":
361                combined_mask[layer_name] = torch.max(layer_mask_tensors, dim=0).values
362            else:
363                raise ValueError(
364                    f"Unsupported mode: {mode}. Use 'intersection' or 'union'."
365                )
366
367        return combined_mask
368
369    def store_mask(self) -> dict[str, Tensor]:
370        r"""Store the binary mask for the current task `self.task_id` in all HAT submodules and return it."""
371        mask_t = self.te_to_binary_mask()
372
373        for subhatmodule in self.modules():
374            if isinstance(subhatmodule, HATMaskBackbone):  # for all sub HAT modules
375                subhatmodule.masks[self.task_id] = mask_t
376
377        return mask_t
378
379    def get_layer_measure_parameter_wise(
380        self,
381        neuron_wise_measure: dict[str, Tensor],
382        layer_name: str,
383        aggregation_mode: str,
384    ) -> tuple[Tensor, Tensor]:
385        r"""Get the parameter-wise measure on the parameters right before the given layer.
386
387        It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.
388
389        Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.
390
391        This method is used to calculate parameter-wise measures in various HAT-based algorithms:
392
393        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks `cumulative_mask_for_previous_tasks`, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
394        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks `summative_mask_for_previous_tasks`, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
395        - **FG-AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks `summative_importance_for_previous_tasks`, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.
396
397        **Args:**
398        - **neuron_wise_measure** (`dict[str, Tensor]`): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
399        - **layer_name** (`str`): The name of the given layer.
400        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
401            - 'min': takes the minimum of the two connected unit measures.
402            - 'max': takes the maximum of the two connected unit measures.
403            - 'mean': takes the mean of the two connected unit measures.
404
405        **Returns:**
406        - **weight_measure** (`Tensor`): The weight measure matrix, the same size as the corresponding weights.
407        - **bias_measure** (`Tensor`): The bias measure vector, the same size as the corresponding bias.
408        """
409
410        # initialize the aggregation function
411        if aggregation_mode == "min":
412            aggregation_func = torch.min
413        elif aggregation_mode == "max":
414            aggregation_func = torch.max
415        elif aggregation_mode == "mean":
416            aggregation_func = lambda m1, m2: (m1 + m2) / 2  # element-wise mean; `torch.mean` does not accept two tensors
417        else:
418            raise ValueError(
419                f"The aggregation method {aggregation_mode} is not supported."
420            )
421
422        # get the preceding layer
423        preceding_layer_name = self.preceding_layer_name(layer_name)
424
425        # get weight size for expanding the measures
426        layer = self.get_layer_by_name(layer_name)
427        weight_size = layer.weight.size()
428
429        # construct the weight-wise measure
430        layer_measure = neuron_wise_measure[layer_name]
431        layer_measure_broadcast_size = (-1, 1) + tuple(
432            1 for _ in range(len(weight_size) - 2)
433        )  # since the size of mask tensor is (number of units, ), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
434
435        layer_measure_broadcasted = layer_measure.view(
436            *layer_measure_broadcast_size
437        ).expand(
438            weight_size,
439        )  # expand the given layer mask to the weight size and broadcast
440
441        if (
442            preceding_layer_name
443        ):  # if the layer is not the first layer, where the preceding layer exists
444
445            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
446                1 for _ in range(len(weight_size) - 2)
447            )  # since the size of mask tensor is (number of units, ), we extend it to (1, number of units) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
448            preceding_layer_measure = neuron_wise_measure[preceding_layer_name]
449            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
450                *preceding_layer_measure_broadcast_size
451            ).expand(
452                weight_size
453            )  # expand the preceding layer mask to the weight size and broadcast
454            weight_measure = aggregation_func(
455                layer_measure_broadcasted, preceding_layer_measure_broadcasted
456            )  # aggregate the two broadcasted measure tensors element-wise
457        else:  # if the layer is the first layer
458            weight_measure = layer_measure_broadcasted
459
460        # construct the bias-wise measure
461        bias_measure = layer_measure
462
463        return weight_measure, bias_measure
464
465    @override
466    def forward(
467        self,
468        input: Tensor,
469        stage: str,
470        s_max: float | None = None,
471        batch_idx: int | None = None,
472        num_batches: int | None = None,
473        test_task_id: int | None = None,
474    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
475        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the units in each layer.
476
477        **Args:**
478        - **input** (`Tensor`): The input tensor from data.
479        - **stage** (`str`): The stage of the forward pass; one of:
480            1. 'train': training stage.
481            2. 'validation': validation stage.
482            3. 'test': testing stage.
483        - **s_max** (`float` | `None`): The maximum scaling factor in the gate function. Applies to the training and validation stages; for other stages, it is `None`. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
484        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
485        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
486        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
487
488        **Returns:**
489        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
490        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
491        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
492        """
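The broadcasting and aggregation logic in `get_layer_measure_parameter_wise` above can be sketched as follows. This is a minimal illustration, not CLArena's actual code; the layer sizes (3 output units, 2 input units of a fully connected layer) are made up:

```python
import torch

# Neuron-wise measures of a hypothetical layer and its preceding layer.
layer_measure = torch.tensor([0.2, 0.9, 0.5])    # size (number of units, ) of the given layer
preceding_measure = torch.tensor([0.4, 0.1])     # size (number of units, ) of the preceding layer
weight_size = (3, 2)                             # size of the weights in between

# View as column / row vectors, then broadcast both to the weight size.
a = layer_measure.view(-1, 1).expand(weight_size)      # rows: units of the given layer
b = preceding_measure.view(1, -1).expand(weight_size)  # columns: units of the preceding layer

weight_measure = torch.min(a, b)  # 'min' aggregation, element-wise
bias_measure = layer_measure      # the bias measure is the given layer's own measure
print(weight_measure)
# tensor([[0.2000, 0.1000],
#         [0.4000, 0.1000],
#         [0.4000, 0.1000]])
```

For convolutional layers, the same views gain extra trailing singleton dimensions so that the measures broadcast over the kernel dimensions as well.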

The backbone network for HAT-based algorithms with learnable hard attention masks.

HAT-based algorithms include:

  • HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
  • AdaHAT (Adaptive Hard Attention to the Task, 2024) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
  • FG-AdaHAT is an architecture-based continual learning approach that improves HAT by introducing fine-grained neuron-wise importance measures that guide the adaptive adjustment mechanism of AdaHAT.
HATMaskBackbone(output_dim: int | None, gate: str, **kwargs)
186    def __init__(self, output_dim: int | None, gate: str, **kwargs) -> None:
187        r"""
188        **Args:**
189        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension; in that case, it can be `None`.
190        - **gate** (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
191            - `sigmoid`: the sigmoid function.
192        - **kwargs**: Reserved for multiple inheritance.
193        """
194        super().__init__(output_dim=output_dim, **kwargs)
195
196        self.gate: str = gate
197        r"""The type of gate function."""
198        self.gate_fn: Callable
199        r"""The gate function mapping the real value task embeddings into attention masks."""
200
201        if gate == "sigmoid":
202            self.gate_fn = nn.Sigmoid()
203
204        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
205        r"""The task embedding for the current task `self.task_id`. Keys are layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).
206        
207        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
208        
209        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
210        
211        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
212        """
213
214        self.masks: dict[int, dict[str, Tensor]] = {}
215        r"""The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ). """
216
217        HATMaskBackbone.sanity_check(self)

Args:

  • output_dim (int | None): The output dimension that connects to CL output heads. The input_dim of output heads is expected to be the same as this output_dim. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension; in that case, it can be None.
  • gate (str): The type of gate function turning the real value task embeddings into attention masks; one of:
    • sigmoid: the sigmoid function.
  • kwargs: Reserved for multiple inheritance.
gate: str

The type of gate function.

gate_fn: Callable

The gate function mapping the real value task embeddings into attention masks.

task_embedding_t: torch.nn.modules.container.ModuleDict

The task embedding for the current task self.task_id. Keys are layer names and values are the task embedding nn.Embedding for the layer. Each task embedding has size (1, number of units).

We use ModuleDict rather than dict to ensure LightningModule properly registers these model parameters for purposes such as automatic device transfer and model summaries.

We use nn.Embedding rather than nn.Parameter to store the task embedding for each layer, which is a type of nn.Module and can be accepted by nn.ModuleDict. (nn.Parameter cannot be accepted by nn.ModuleDict.)

This must be defined to cover each weighted layer (as listed in weighted_layer_names) in the backbone network. Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.

masks: dict[int, dict[str, torch.Tensor]]

The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ).
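As a minimal sketch of how the attributes above fit together (the layer names "fc1"/"fc2" and unit counts are hypothetical, not CLArena's actual layout):

```python
import torch
from torch import nn

# One task embedding of size (1, number of units) per weighted layer, held in a
# ModuleDict so Lightning registers the parameters.
task_embedding_t = nn.ModuleDict({
    "fc1": nn.Embedding(1, 16),  # 16-unit hypothetical layer
    "fc2": nn.Embedding(1, 8),   # 8-unit hypothetical layer
})
gate_fn = nn.Sigmoid()

# Gating an embedding yields a per-unit attention vector in (0, 1); with a large
# scaling factor (as s_max is) it saturates towards a binary mask.
mask_fc1 = gate_fn(task_embedding_t["fc1"].weight * 400.0).squeeze()
print(mask_fc1.shape)  # torch.Size([16])
```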

def initialize_task_embedding(self, mode: str) -> None:
219    def initialize_task_embedding(self, mode: str) -> None:
220        r"""Initialize the task embedding for the current task `self.task_id`.
221
222        **Args:**
223        - **mode** (`str`): The initialization mode for task embeddings; one of:
224            1. 'N01' (default): standard normal distribution $N(0, 1)$.
225            2. 'U-11': uniform distribution $U(-1, 1)$.
226            3. 'U01': uniform distribution $U(0, 1)$.
227            4. 'U-10': uniform distribution $U(-1, 0)$.
228            5. 'last': inherit task embeddings from the last task.
229        """
230        for te in self.task_embedding_t.values():
231            if mode == "N01":
232                nn.init.normal_(te.weight, 0, 1)
233            elif mode == "U-11":
234                nn.init.uniform_(te.weight, -1, 1)
235            elif mode == "U01":
236                nn.init.uniform_(te.weight, 0, 1)
237            elif mode == "U-10":
238                nn.init.uniform_(te.weight, -1, 0)
239            elif mode == "last":
240                pass  # keep the task embedding values inherited from the last task

Initialize the task embedding for the current task self.task_id.

Args:

  • mode (str): The initialization mode for task embeddings; one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit task embeddings from the last task.
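The modes above map directly onto `torch.nn.init` calls. A small sketch on a single hypothetical embedding:

```python
import torch
from torch import nn

te = nn.Embedding(1, 16)  # one task embedding for a hypothetical 16-unit layer

nn.init.uniform_(te.weight, -1, 1)  # mode 'U-11': U(-1, 1)
assert float(te.weight.min()) >= -1.0 and float(te.weight.max()) <= 1.0

nn.init.normal_(te.weight, 0, 1)    # mode 'N01': N(0, 1)
print(te.weight.shape)              # torch.Size([1, 16])
```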
def sanity_check(self) -> None:
242    def sanity_check(self) -> None:
243        r"""Sanity check."""
244
245        if self.gate not in ["sigmoid"]:
246            raise ValueError("The gate should be one of: 'sigmoid'.")

Sanity check.

def get_mask( self, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_task_id: int | None = None) -> dict[str, torch.Tensor]:
248    def get_mask(
249        self,
250        stage: str,
251        s_max: float | None = None,
252        batch_idx: int | None = None,
253        num_batches: int | None = None,
254        test_task_id: int | None = None,
255    ) -> dict[str, Tensor]:
256        r"""Get the hard attention mask used in the `forward()` method for different stages.
257
258        **Args:**
259        - **stage** (`str`): The stage when applying the conversion; one of:
260            1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
261            2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by `s_max`, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
262            3. 'test': testing stage. Apply the test mask directly from the stored masks using `test_task_id`.
263            4. 'unlearning_test': unlearning testing stage. The mask is set to all 1s for unlearning testing.
264        - **s_max** (`float` | `None`): The maximum scaling factor in the gate function. Applies to the training and validation stages; for other stages, it is `None`. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
265        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
266        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
267        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
268
269        **Returns:**
270        - **mask** (`dict[str, Tensor]`): The hard attention mask. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors of size (number of units, ). The values are binary at the 'test' and 'unlearning_test' stages and lie in (0, 1) at the 'train' and 'validation' stages.
271        """
272
273        # sanity check
274        if stage == "train" and (
275            s_max is None or batch_idx is None or num_batches is None
276        ):
277            raise ValueError(
278                "`s_max`, `batch_idx` and `num_batches` must be provided at the training stage instead of the default `None`."
279            )
280        if stage == "validation" and (s_max is None):
281            raise ValueError(
282                "`s_max` must be provided at the validation stage instead of the default `None`."
283            )
284        if stage == "test" and (test_task_id is None):
285            raise ValueError(
286                "`test_task_id` must be provided at the testing stage instead of the default `None`."
287            )
288
289        mask = {}
290        if stage == "train":
291            for layer_name in self.weighted_layer_names:
292                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
293                    num_batches - 1
294                )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
295                mask[layer_name] = self.gate_fn(
296                    self.task_embedding_t[layer_name].weight * anneal_scalar
297                ).squeeze()
298        elif stage == "validation":
299            for layer_name in self.weighted_layer_names:
300                mask[layer_name] = self.gate_fn(
301                    self.task_embedding_t[layer_name].weight * s_max
302                ).squeeze()
303        elif stage == "test":
304            mask = self.masks[test_task_id]
305            for layer_name, layer_mask in mask.items():
306                layer = self.get_layer_by_name(layer_name)
307                target_device = layer.weight.device
308                if layer_mask.device != target_device:
309                    mask[layer_name] = layer_mask.to(target_device)
310        elif stage == "unlearning_test":
311            for layer_name in self.weighted_layer_names:
312                layer = self.get_layer_by_name(layer_name)
313                mask[layer_name] = torch.ones(
314                    layer.weight.size(0), device=layer.weight.device
315                )
316
317        return mask

Get the hard attention mask used in the forward() method for different stages.

Args:

  • stage (str): The stage when applying the conversion; one of:
    1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the HAT paper.
    2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by s_max, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
    3. 'test': testing stage. Apply the test mask directly from the stored masks using test_task_id.
    4. 'unlearning_test': unlearning testing stage. The mask is set to all 1s for unlearning testing.
  • s_max (float | None): The maximum scaling factor in the gate function. Applies to the training and validation stages; for other stages, it is None. See Sec. 2.4 in the HAT paper.
  • batch_idx (int | None): The current batch index. Applies only to the training stage. For other stages, it is None.
  • num_batches (int | None): The total number of batches. Applies only to the training stage. For other stages, it is None.
  • test_task_id (int | None): The test task ID. Applies only to the testing stage. For other stages, it is None.

Returns:

  • mask (dict[str, Tensor]): The hard attention mask. Keys (str) are the layer names and values (Tensor) are the mask tensors of size (number of units, ). The values are binary at the 'test' and 'unlearning_test' stages and lie in (0, 1) at the 'train' and 'validation' stages.
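The annealing in the 'train' branch can be sketched as follows. This follows Eq. (3) in the HAT paper; the values of `s_max` and `num_batches` are illustrative:

```python
import torch

s_max, num_batches = 400.0, 100  # illustrative values

def anneal_scalar(batch_idx: int) -> float:
    # Eq. (3): ramp from 1/s_max at the first batch to s_max at the last batch.
    return 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)

print(anneal_scalar(1))            # 0.0025  -> soft, near-linear gate, gradients flow
print(anneal_scalar(num_batches))  # 400.0   -> near-binary gate

# With the large final scalar, the sigmoid gate saturates towards a 0/1 mask:
embedding = torch.tensor([[-0.5, 0.01, 0.8]])  # illustrative task embedding values
mask = torch.sigmoid(embedding * anneal_scalar(num_batches)).squeeze()
```

The schedule keeps the gate soft early in each epoch so the task embedding remains trainable, and hardens it towards the binary mask that is eventually stored.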
def te_to_binary_mask(self) -> dict[str, torch.Tensor]:
319    def te_to_binary_mask(self) -> dict[str, Tensor]:
320        r"""Convert the current task embedding into a binary mask.
321
322        This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.
323
324        **Returns:**
325        - **mask_t** (`dict[str, Tensor]`): The binary mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
326        """
327        # get the mask for the current task
328        mask_t = {
329            layer_name: (self.task_embedding_t[layer_name].weight > 0)
330            .float()
331            .squeeze()
332            .detach()
333            for layer_name in self.weighted_layer_names
334        }
335
336        return mask_t

Convert the current task embedding into a binary mask.

This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.

Returns:

  • mask_t (dict[str, Tensor]): The binary mask for the current task. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has size (number of units, ).
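The thresholding performed by `te_to_binary_mask()` can be sketched in isolation (the embedding values below are illustrative). Note that sigmoid(x) > 0.5 exactly when x > 0, so thresholding the embedding at 0 agrees with the gated mask in its binary limit:

```python
import torch
from torch import nn

te = nn.Embedding(1, 4)  # hypothetical 4-unit layer
with torch.no_grad():
    te.weight.copy_(torch.tensor([[-0.7, 0.2, 0.0, 1.5]]))

# Positive entries map to 1, non-positive entries to 0, as in te_to_binary_mask().
mask_t = (te.weight > 0).float().squeeze().detach()
print(mask_t)  # tensor([0., 1., 0., 1.])
```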
def combine_masks( self, masks: list[dict[str, torch.Tensor]], mode: str) -> dict[str, torch.Tensor]:
338    def combine_masks(
339        self, masks: list[dict[str, Tensor]], mode: str
340    ) -> dict[str, Tensor]:
341        r"""Combine multiple masks by taking their element-wise minimum (for intersection) / maximum (for union).
342
343        **Args:**
344        - **masks** (`list[dict[str, Tensor]]`): A list of masks. Each mask is a dict where keys are layer names and values are mask tensors.
345        - **mode** (`str`): The combination mode; one of:
346            - 'intersection': take the element-wise minimum of the masks (for intersection).
347            - 'union': take the element-wise maximum of the masks (for union).
348
349        **Returns:**
350        - **combined_mask** (`dict[str, Tensor]`): The combined mask.
351        """
352
353        combined_mask = {}
354        for layer_name in masks[0].keys():
355            layer_mask_tensors = torch.stack(
356                [mask[layer_name] for mask in masks], dim=0
357            )
358            if mode == "intersection":
359                combined_mask[layer_name] = torch.min(layer_mask_tensors, dim=0).values
360            elif mode == "union":
361                combined_mask[layer_name] = torch.max(layer_mask_tensors, dim=0).values
362            else:
363                raise ValueError(
364                    f"Unsupported mode: {mode}. Use 'intersection' or 'union'."
365                )
366
367        return combined_mask

Combine multiple masks by taking their element-wise minimum (for intersection) / maximum (for union).

Args:

  • masks (list[dict[str, Tensor]]): A list of masks. Each mask is a dict where keys are layer names and values are mask tensors.
  • mode (str): The combination mode; one of:
    • 'intersection': take the element-wise minimum of the masks (for intersection).
    • 'union': take the element-wise maximum of the masks (for union).

Returns:

  • combined_mask (dict[str, Tensor]): The combined mask.
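A small sketch of `combine_masks` on two single-layer masks ("fc1" is an illustrative layer name):

```python
import torch

mask_a = {"fc1": torch.tensor([1.0, 1.0, 0.0, 0.0])}
mask_b = {"fc1": torch.tensor([1.0, 0.0, 1.0, 0.0])}

stacked = torch.stack([mask_a["fc1"], mask_b["fc1"]], dim=0)
intersection = torch.min(stacked, dim=0).values  # units active in both masks
union = torch.max(stacked, dim=0).values         # units active in either mask
print(intersection)  # tensor([1., 0., 0., 0.])
print(union)         # tensor([1., 1., 1., 0.])
```

Element-wise minimum and maximum coincide with logical AND and OR on binary masks, which is why they implement intersection and union here.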
def store_mask(self) -> dict[str, torch.Tensor]:
369    def store_mask(self) -> dict[str, Tensor]:
370        r"""Store the binary mask for the current task `self.task_id` in all HAT submodules and return it."""
371        mask_t = self.te_to_binary_mask()
372
373        for subhatmodule in self.modules():
374            if isinstance(subhatmodule, HATMaskBackbone):  # for all sub HAT modules
375                subhatmodule.masks[self.task_id] = mask_t
376
377        return mask_t

Store the binary mask for the current task self.task_id in all HAT submodules and return it.

def get_layer_measure_parameter_wise( self, neuron_wise_measure: dict[str, torch.Tensor], layer_name: str, aggregation_mode: str) -> tuple[torch.Tensor, torch.Tensor]:
379    def get_layer_measure_parameter_wise(
380        self,
381        neuron_wise_measure: dict[str, Tensor],
382        layer_name: str,
383        aggregation_mode: str,
384    ) -> tuple[Tensor, Tensor]:
385        r"""Get the parameter-wise measure on the parameters right before the given layer.
386
387        It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.
388
389        Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.
390
391        This method is used to calculate parameter-wise measures in various HAT-based algorithms:
392
393        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks `cumulative_mask_for_previous_tasks`, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
394        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks `summative_mask_for_previous_tasks`, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
395        - **FG-AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks `summative_importance_for_previous_tasks`, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.
396
397        **Args:**
398        - **neuron_wise_measure** (`dict[str, Tensor]`): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
399        - **layer_name** (`str`): The name of the given layer.
400        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
401            - 'min': takes the minimum of the two connected unit measures.
402            - 'max': takes the maximum of the two connected unit measures.
403            - 'mean': takes the mean of the two connected unit measures.
404
405        **Returns:**
406        - **weight_measure** (`Tensor`): The weight measure matrix, the same size as the corresponding weights.
407        - **bias_measure** (`Tensor`): The bias measure vector, the same size as the corresponding bias.
408        """
409
410        # initialize the aggregation function
411        if aggregation_mode == "min":
412            aggregation_func = torch.min
413        elif aggregation_mode == "max":
414            aggregation_func = torch.max
415        elif aggregation_mode == "mean":
416            aggregation_func = lambda a, b: (a + b) / 2  # elementwise mean; torch.mean does not accept two tensors
417        else:
418            raise ValueError(
419                f"The aggregation mode {aggregation_mode} is not supported."
420            )
421
422        # get the preceding layer
423        preceding_layer_name = self.preceding_layer_name(layer_name)
424
425        # get weight size for expanding the measures
426        layer = self.get_layer_by_name(layer_name)
427        weight_size = layer.weight.size()
428
429        # construct the weight-wise measure
430        layer_measure = neuron_wise_measure[layer_name]
431        layer_measure_broadcast_size = (-1, 1) + tuple(
432            1 for _ in range(len(weight_size) - 2)
433        )  # since the size of mask tensor is (number of units, ), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
434
435        layer_measure_broadcasted = layer_measure.view(
436            *layer_measure_broadcast_size
437        ).expand(
438            weight_size,
439        )  # expand the given layer mask to the weight size and broadcast
440
441        if (
442            preceding_layer_name
443        ):  # if the layer is not the first layer, where the preceding layer exists
444
445            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
446                1 for _ in range(len(weight_size) - 2)
447            )  # since the size of mask tensor is (number of units, ), we extend it to (1, number of units) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
448            preceding_layer_measure = neuron_wise_measure[preceding_layer_name]
449            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
450                *preceding_layer_measure_broadcast_size
451            ).expand(
452                weight_size
453            )  # expand the preceding layer mask to the weight size and broadcast
454            weight_measure = aggregation_func(
455                layer_measure_broadcasted, preceding_layer_measure_broadcasted
456        )  # aggregate the two broadcast measures into the weight-wise measure
457        else:  # if the layer is the first layer
458            weight_measure = layer_measure_broadcasted
459
460        # construct the bias-wise measure
461        bias_measure = layer_measure
462
463        return weight_measure, bias_measure

Get the parameter-wise measure on the parameters right before the given layer.

It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.

Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.

This method is used to calculate parameter-wise measures in various HAT-based algorithms:

  • HAT: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks cumulative_mask_for_previous_tasks, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the HAT paper.
  • AdaHAT: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks summative_mask_for_previous_tasks, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the AdaHAT paper.
  • FG-AdaHAT: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks summative_importance_for_previous_tasks, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.

Args:

  • neuron_wise_measure (dict[str, Tensor]): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
  • layer_name (str): The name of the given layer.
  • aggregation_mode (str): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
    • 'min': takes the minimum of the two connected unit measures.
    • 'max': takes the maximum of the two connected unit measures.
    • 'mean': takes the mean of the two connected unit measures.

Returns:

  • weight_measure (Tensor): The weight measure matrix, the same size as the corresponding weights.
  • bias_measure (Tensor): The bias measure vector, the same size as the corresponding bias.
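The aggregation above can be illustrated with a minimal plain-Python sketch (a conceptual illustration only, not the library's tensor-based implementation, which broadcasts torch Tensors to the weight size):

```python
# Conceptual sketch of 'min' aggregation between two neuron-wise measures.
# m_layer[i] is the measure of unit i in the given layer; m_prev[j] is the
# measure of unit j in the preceding layer. The weight-wise measure is an
# i-by-j matrix; the bias-wise measure is just the given layer's vector.

def aggregate_min(m_layer, m_prev):
    weight_measure = [[min(a, b) for b in m_prev] for a in m_layer]
    bias_measure = list(m_layer)
    return weight_measure, bias_measure

weight_measure, bias_measure = aggregate_min([0.9, 0.1], [0.5, 0.3, 0.8])
# weight_measure is 2 x 3, matching a weight matrix between a 3-unit
# preceding layer and a 2-unit given layer (PyTorch weight shape:
# (output features, input features))
```

With 'min' aggregation, a weight is considered masked only to the degree that both units it connects are masked, which is how HAT-style methods derive parameter-wise protection from neuron-wise masks.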
@override
def forward( self, input: torch.Tensor, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor]]:
465    @override
466    def forward(
467        self,
468        input: Tensor,
469        stage: str,
470        s_max: float | None = None,
471        batch_idx: int | None = None,
472        num_batches: int | None = None,
473        test_task_id: int | None = None,
474    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
475        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the units in each layer.
476
477        **Args:**
478        - **input** (`Tensor`): The input tensor from data.
479        - **stage** (`str`): The stage of the forward pass; one of:
480            1. 'train': training stage.
481            2. 'validation': validation stage.
482            3. 'test': testing stage.
483        - **s_max** (`float`): The maximum scaling factor in the gate function. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
484        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
485        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
486        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
487
488        **Returns:**
489        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
490        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
491        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
492        """

The forward pass for data from task self.task_id. Task-specific masks for self.task_id are applied to the units in each layer.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • s_max (float): The maximum scaling factor in the gate function. See Sec. 2.4 in the HAT paper.
  • batch_idx (int | None): The current batch index. Applies only to the training stage. For other stages, it is None.
  • num_batches (int | None): The total number of batches. Applies only to the training stage. For other stages, it is None.
  • test_task_id (int | None): The test task ID. Applies only to the testing stage. For other stages, it is None.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • mask (dict[str, Tensor]): The mask for the current task. Keys (str) are layer names and values (Tensor) are the mask tensors. The mask tensor has size (number of units, ).
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this forward() method of the HAT class.
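For intuition on how `s_max`, `batch_idx`, and `num_batches` interact during training: Sec. 2.4 of the HAT paper anneals the gate's scaling factor $s$ linearly over the batches of an epoch, from $1/s_{\max}$ at the first batch to $s_{\max}$ at the last. A minimal sketch, assuming a zero-based `batch_idx` (the actual annealing is performed in the HAT algorithm class, not in this backbone):

```python
def annealed_scaling_factor(s_max: float, batch_idx: int, num_batches: int) -> float:
    # Linear annealing of the gate scaling factor s across an epoch:
    # s = 1/s_max at the first batch, s = s_max at the last batch
    # (Sec. 2.4 of the HAT paper); batch_idx is assumed zero-based.
    return 1 / s_max + (s_max - 1 / s_max) * batch_idx / (num_batches - 1)

s_first = annealed_scaling_factor(400.0, 0, 100)   # ≈ 1 / 400 = 0.0025
s_last = annealed_scaling_factor(400.0, 99, 100)   # ≈ 400.0
```

A small $s$ keeps the sigmoid gate soft (gradients flow to the task embeddings), while a large $s$ pushes the gate toward a binary mask.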
class AmnesiacHATBackbone(clarena.backbones.HATMaskBackbone):
495class AmnesiacHATBackbone(HATMaskBackbone):
496    r"""The backbone network for AmnesiacHAT on top of HAT. AmnesiacHAT introduces parallel backup backbones to guard against side effects caused by unlearning."""
497
498    original_backbone_class: type[Backbone]
499    r"""The original backbone class used to instantiate backup backbones. Must be defined in subclasses."""
500
501    def __init__(
502        self,
503        output_dim: int | None,
504        gate: str,
505        disable_unlearning: bool = False,
506        **kwargs,
507    ) -> None:
508        r"""
509        **Args:**
510        - **output_dim** (`int`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
511        - **gate** (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
512            - `sigmoid`: the sigmoid function.
513        - **disable_unlearning** (`bool`): Whether to disable unlearning. This is used in reference experiments following the continual learning pipeline. Default is `False`.
514        - **kwargs**: Reserved for multiple inheritance.
515        """
516        super().__init__(output_dim=output_dim, gate=gate, **kwargs)
517
518        self.disable_unlearning: bool = disable_unlearning
518        r"""Whether to disable unlearning. This is used in reference experiments following the continual learning pipeline."""
520
521        if not disable_unlearning:
522            self.backup_backbones: nn.ModuleDict
523            r"""The backup backbone networks. Keys are task IDs (in string format, because ModuleDict keys have to be strings) identifying the tasks whose backbones are backed up in case they are later unlearned; values are the corresponding backup backbone networks. They all have the same architecture as the main backbone network.
524            
525            Please note that we use `ModuleDict` rather than `dict` to ensure `LightningModule` can track these model parameters for training. DO NOT change this to `dict`."""
526
527            self.backup_task_ids: list[int]
528            r"""The task IDs that need to have backup backbones at current task `self.task_id`."""
529
530            self.backup_state_dicts: dict[tuple[int, int], dict[str, Tensor]] = {}
531            r"""The backup state dict for each task. Keys are tuples of (backup task ID, the task ID that the backup is for) and values are the corresponding state dicts."""
532
533    def instantiate_backup_backbones(
534        self,
535        backup_task_ids: list[int],
536    ) -> None:
537        r"""Instantiate the backup backbone networks for the current task. This is called when a new task is created.
538
539        **Args:**
540        - **backup_task_ids** (`list[int]`): The list of task IDs to back up at the current task `self.task_id`.
541        """
542
543        self.backup_task_ids = backup_task_ids
544
545        self.backup_backbones = nn.ModuleDict(
546            {
547                f"{task_id_to_backup}": self.original_backbone_class(
548                    **self.backup_backbone_kwargs,
549                )
550                for task_id_to_backup in backup_task_ids
551            }
552        )
553
554        pylogger.debug(
555            "Backup backbones (backing up task IDs %s) for current task ID %d have been instantiated.",
556            backup_task_ids,
557            self.task_id,
558        )

The backbone network for AmnesiacHAT on top of HAT. AmnesiacHAT introduces parallel backup backbones to guard against side effects caused by unlearning.

AmnesiacHATBackbone( output_dim: int | None, gate: str, disable_unlearning: bool = False, **kwargs)
501    def __init__(
502        self,
503        output_dim: int | None,
504        gate: str,
505        disable_unlearning: bool = False,
506        **kwargs,
507    ) -> None:
508        r"""
509        **Args:**
510        - **output_dim** (`int`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
511        - **gate** (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
512            - `sigmoid`: the sigmoid function.
513        - **disable_unlearning** (`bool`): Whether to disable unlearning. This is used in reference experiments following the continual learning pipeline. Default is `False`.
514        - **kwargs**: Reserved for multiple inheritance.
515        """
516        super().__init__(output_dim=output_dim, gate=gate, **kwargs)
517
518        self.disable_unlearning: bool = disable_unlearning
518        r"""Whether to disable unlearning. This is used in reference experiments following the continual learning pipeline."""
520
521        if not disable_unlearning:
522            self.backup_backbones: nn.ModuleDict
523            r"""The backup backbone networks. Keys are task IDs (in string format, because ModuleDict keys have to be strings) identifying the tasks whose backbones are backed up in case they are later unlearned; values are the corresponding backup backbone networks. They all have the same architecture as the main backbone network.
524            
525            Please note that we use `ModuleDict` rather than `dict` to ensure `LightningModule` can track these model parameters for training. DO NOT change this to `dict`."""
526
527            self.backup_task_ids: list[int]
528            r"""The task IDs that need to have backup backbones at current task `self.task_id`."""
529
530            self.backup_state_dicts: dict[tuple[int, int], dict[str, Tensor]] = {}
531            r"""The backup state dict for each task. Keys are tuples of (backup task ID, the task ID that the backup is for) and values are the corresponding state dicts."""

Args:

  • output_dim (int): The output dimension that connects to CL output heads. The input_dim of output heads is expected to be the same as this output_dim. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be None.
  • gate (str): The type of gate function turning the real value task embeddings into attention masks; one of:
    • sigmoid: the sigmoid function.
  • disable_unlearning (bool): Whether to disable unlearning. This is used in reference experiments following the continual learning pipeline. Default is False.
  • kwargs: Reserved for multiple inheritance.
original_backbone_class: type[Backbone]

The original backbone class used to instantiate backup backbones. Must be defined in subclasses.

disable_unlearning: bool

Whether to disable unlearning. This is used in reference experiments following the continual learning pipeline.

def instantiate_backup_backbones(self, backup_task_ids: list[int]) -> None:
533    def instantiate_backup_backbones(
534        self,
535        backup_task_ids: list[int],
536    ) -> None:
537        r"""Instantiate the backup backbone networks for the current task. This is called when a new task is created.
538
539        **Args:**
540        - **backup_task_ids** (`list[int]`): The list of task IDs to back up at the current task `self.task_id`.
541        """
542
543        self.backup_task_ids = backup_task_ids
544
545        self.backup_backbones = nn.ModuleDict(
546            {
547                f"{task_id_to_backup}": self.original_backbone_class(
548                    **self.backup_backbone_kwargs,
549                )
550                for task_id_to_backup in backup_task_ids
551            }
552        )
553
554        pylogger.debug(
555            "Backup backbones (backing up task IDs %s) for current task ID %d have been instantiated.",
556            backup_task_ids,
557            self.task_id,
558        )

Instantiate the backup backbone networks for the current task. This is called when a new task is created.

Args:

  • backup_task_ids (list[int]): The list of task IDs to back up at the current task self.task_id.
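The bookkeeping above can be sketched with plain dictionaries (a conceptual illustration; `FakeBackbone` is a hypothetical stand-in for `original_backbone_class`, and the real code stores the networks in an `nn.ModuleDict`, which is why task IDs are stringified):

```python
# Conceptual sketch of instantiate_backup_backbones(): one fresh backbone
# per task ID to back up, keyed by the stringified task ID (ModuleDict
# keys must be strings). FakeBackbone is a hypothetical stand-in.

class FakeBackbone:
    def __init__(self, output_dim):
        self.output_dim = output_dim

def instantiate_backups(backbone_class, backbone_kwargs, backup_task_ids):
    return {
        f"{task_id}": backbone_class(**backbone_kwargs)
        for task_id in backup_task_ids
    }

backups = instantiate_backups(FakeBackbone, {"output_dim": 64}, [1, 3])
```

Each backup is a freshly constructed network with the same architecture as the main backbone; nothing is shared between the backups themselves.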
class WSNMaskBackbone(clarena.backbones.CLBackbone):
561class WSNMaskBackbone(CLBackbone):
562    r"""The backbone network for the WSN algorithm with learnable parameter masks.
563
564    [WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) is an architecture-based continual learning algorithm. It trains learnable parameter-wise scores and selects the top-scoring $c\%$ of the network parameters to be used for each task.
565    """
566
567    def __init__(self, output_dim: int | None, **kwargs) -> None:
568        r"""
569        **Args:**
570        - **output_dim** (`int`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
571        - **kwargs**: Reserved for multiple inheritance.
572        """
573        super().__init__(output_dim=output_dim, **kwargs)
574
575        self.gate_fn: torch.autograd.Function = PercentileLayerParameterMaskingByScore
576        r"""The gate function mapping the real-value parameter score into binary parameter masks. It is a custom autograd function that applies percentile parameter masking by score."""
577
578        self.weight_score_t: nn.ModuleDict = nn.ModuleDict()
579        r"""The weight score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (output features, input features) as the weight.
580        
581        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
582        
583        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
584        
585        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
586        """
587
588        self.bias_score_t: nn.ModuleDict = nn.ModuleDict()
589        r"""The bias score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, output features), corresponding to the bias of size (output features, ). If the layer doesn't have a bias, it is `None`.
590        
591        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
592        
593        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
594        
595        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
596        """
597
598        WSNMaskBackbone.sanity_check(self)
599
600    def sanity_check(self) -> None:
601        r"""Sanity check."""
602
603    def initialize_parameter_score(self, mode: str) -> None:
604        r"""Initialize the parameter score for the current task.
605
606        **Args:**
607        - **mode** (`str`): The initialization mode for parameter scores; one of:
608            1. 'default': the default initialization mode in the original WSN code.
609            2. 'N01': standard normal distribution $N(0, 1)$.
610            3. 'U01': uniform distribution $U(0, 1)$.
611        """
612
613        # initialize the weight and bias scores layer by layer
614        for layer_name, weight_score in self.weight_score_t.items():
615            bias_score = self.bias_score_t[layer_name]
616
617            if mode == "default":
618                # Kaiming Uniform Initialization for weight score
619                nn.init.kaiming_uniform_(weight_score.weight, a=math.sqrt(5))
620
621                if bias_score is not None:
622                    # For bias, follow the standard bias initialization using fan_in
623                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
624                        weight_score.weight
625                    )
626                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
627                    nn.init.uniform_(bias_score.weight, -bound, bound)
628            elif mode == "N01":
629                nn.init.normal_(weight_score.weight, 0, 1)
630                if bias_score is not None:
631                    nn.init.normal_(bias_score.weight, 0, 1)
632            elif mode == "U01":
633                nn.init.uniform_(weight_score.weight, 0, 1)
634                if bias_score is not None:
635                    nn.init.uniform_(bias_score.weight, 0, 1)
636
637
638    def get_mask(
639        self,
640        stage: str,
641        mask_percentage: float,
642        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
643    ) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
644        r"""Get the binary parameter mask used in the `forward()` method for different stages.
645
646        **Args:**
647        - **stage** (`str`): The stage when applying the conversion; one of:
648            1. 'train': training stage. Get the mask from the parameter score of the current task through the gate function that masks the top $c\%$ largest scored parameters. See Sec. 3.1 "Winning Subnetworks" in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
649            2. 'validation': validation stage. Same as 'train'. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
650            3. 'test': testing stage. Apply the test mask directly from the argument `test_mask`.
           - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
651        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
652
653        **Returns:**
654        - **weight_mask** (`dict[str, Tensor]`): The binary mask on weights. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
655        - **bias_mask** (`dict[str, Tensor]`): The binary mask on biases. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
656        """
657        weight_mask = {}
658        bias_mask = {}
659        if stage == "train" or stage == "validation":
660            for layer_name in self.weighted_layer_names:
661                weight_mask[layer_name] = self.gate_fn.apply(
662                    self.weight_score_t[layer_name].weight, mask_percentage
663                )
664                if self.bias_score_t[layer_name] is not None:
665                    bias_mask[layer_name] = self.gate_fn.apply(
666                        self.bias_score_t[layer_name].weight.squeeze(
667                            0
668                        ),  # from (1, output_dim) to (output_dim, )
669                        mask_percentage,
670                    )
671                else:
672                    bias_mask[layer_name] = None
673        elif stage == "test":
674            weight_mask, bias_mask = test_mask
675
676        return weight_mask, bias_mask
677
678    @override
679    def forward(
680        self,
681        input: Tensor,
682        stage: str,
683        mask_percentage: float,
684        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
685    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
686        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the parameters in each layer.
687
688        **Args:**
689        - **input** (`Tensor`): The input tensor from data.
690        - **stage** (`str`): The stage of the forward pass; one of:
691            1. 'train': training stage.
692            2. 'validation': validation stage.
693            3. 'test': testing stage.
694        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
695        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
696
697        **Returns:**
698        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
699        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
700        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
701        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need the hidden features for various purposes.
702        """

The backbone network for the WSN algorithm with learnable parameter masks.

WSN (Winning Subnetworks, 2022) is an architecture-based continual learning algorithm. It trains learnable parameter-wise scores and selects the top-scoring $c\%$ of the network parameters to be used for each task.
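The top-$c\%$ selection can be sketched in plain Python (a conceptual illustration only; the library's `PercentileLayerParameterMaskingByScore` is a custom autograd function that additionally defines gradients so the scores remain trainable):

```python
def top_c_percent_mask(scores, c):
    # Binary mask selecting the top c fraction (0 < c <= 1) of parameters
    # by score. A higher score makes a parameter more likely to be part
    # of the winning subnetwork for the current task.
    k = max(1, int(len(scores) * c))
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    mask = [0] * len(scores)
    for i in top_indices:
        mask[i] = 1
    return mask

mask = top_c_percent_mask([0.2, 0.9, 0.5, 0.1], 0.5)  # selects the two highest scores
```

In the library, this selection is applied per layer to the weight and bias scores, producing the binary parameter masks used in the forward pass.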

WSNMaskBackbone(output_dim: int | None, **kwargs)
567    def __init__(self, output_dim: int | None, **kwargs) -> None:
568        r"""
569        **Args:**
570        - **output_dim** (`int`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
571        - **kwargs**: Reserved for multiple inheritance.
572        """
573        super().__init__(output_dim=output_dim, **kwargs)
574
575        self.gate_fn: torch.autograd.Function = PercentileLayerParameterMaskingByScore
576        r"""The gate function mapping the real-value parameter score into binary parameter masks. It is a custom autograd function that applies percentile parameter masking by score."""
577
578        self.weight_score_t: nn.ModuleDict = nn.ModuleDict()
579        r"""The weight score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (output features, input features) as the weight.
580        
581        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
582        
583        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
584        
585        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
586        """
587
588        self.bias_score_t: nn.ModuleDict = nn.ModuleDict()
589        r"""The bias score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, output features), corresponding to the bias of size (output features, ). If the layer doesn't have a bias, it is `None`.
590        
591        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
592        
593        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
594        
595        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
596        """
597
598        WSNMaskBackbone.sanity_check(self)

Args:

  • output_dim (int): The output dimension that connects to CL output heads. The input_dim of output heads is expected to be the same as this output_dim. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be None.
  • kwargs: Reserved for multiple inheritance.
gate_fn: torch.autograd.function.Function

The gate function mapping the real-value parameter score into binary parameter masks. It is a custom autograd function that applies percentile parameter masking by score.

weight_score_t: torch.nn.modules.container.ModuleDict

The weight score for the current task self.task_id. Keys are the layer names and values are the task embedding nn.Embedding for the layer. Each task embedding has the same size (output features, input features) as the weight.

We use ModuleDict rather than dict to ensure LightningModule properly registers these model parameters for purposes such as automatic device transfer and model summaries.

We use nn.Embedding rather than nn.Parameter to store the task embedding for each layer, which is a type of nn.Module and can be accepted by nn.ModuleDict. (nn.Parameter cannot be accepted by nn.ModuleDict.)

This must be defined to cover each weighted layer (as listed in weighted_layer_names) in the backbone network. Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.

bias_score_t: torch.nn.modules.container.ModuleDict

The bias score for the current task self.task_id. Keys are the layer names and values are the task embedding nn.Embedding for the layer. Each task embedding has the same size (1, output features) as the bias. If the layer doesn't have a bias, it is None.

We use ModuleDict rather than dict to ensure LightningModule properly registers these model parameters for purposes such as automatic device transfer and model summaries.

We use nn.Embedding rather than nn.Parameter to store the task embedding for each layer, which is a type of nn.Module and can be accepted by nn.ModuleDict. (nn.Parameter cannot be accepted by nn.ModuleDict.)

This must be defined to cover each weighted layer (as listed in weighted_layer_names) in the backbone network. Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.

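The registration rationale above can be sketched as follows. The layer names and sizes here are hypothetical, chosen only to show that `nn.Embedding` scores held in an `nn.ModuleDict` become registered parameters:

```python
import torch
from torch import nn

# Hypothetical layer shapes: (output features, input features)
layer_shapes = {"fc1": (16, 8), "fc2": (4, 16)}

weight_score_t = nn.ModuleDict()
for name, (out_features, in_features) in layer_shapes.items():
    # nn.Embedding is an nn.Module, so nn.ModuleDict accepts it;
    # a bare nn.Parameter would be rejected with a TypeError.
    weight_score_t[name] = nn.Embedding(out_features, in_features)

# The scores are now registered parameters: optimizers, .to(device),
# and Lightning model summaries all see them automatically.
registered = dict(weight_score_t.named_parameters())
```

A plain `dict` of `nn.Parameter`s attached to the module would be invisible to `named_parameters()`, so device transfer and optimizer setup would silently skip the scores.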
def sanity_check(self) -> None:
600    def sanity_check(self) -> None:
601        r"""Sanity check."""

Sanity check.

def initialize_parameter_score(self, mode: str) -> None:
603    def initialize_parameter_score(self, mode: str) -> None:
604        r"""Initialize the parameter score for the current task.
605
606        **Args:**
607        - **mode** (`str`): The initialization mode for parameter scores; one of:
608            1. 'default': the default initialization mode in the original WSN code.
609            2. 'N01': standard normal distribution $N(0, 1)$.
610            3. 'U01': uniform distribution $U(0, 1)$.
611        """
612
613        for layer_name, weight_score in self.weight_score_t.items():
614            bias_score = self.bias_score_t[layer_name]
615            if mode == "default":
616                # Kaiming Uniform Initialization for weight score
617                nn.init.kaiming_uniform_(weight_score.weight, a=math.sqrt(5))
618                if bias_score is not None:
619                    # For bias, follow the standard bias initialization using fan_in
620                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
621                        weight_score.weight
622                    )
623                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
624                    nn.init.uniform_(bias_score.weight, -bound, bound)
625            elif mode == "N01":
626                nn.init.normal_(weight_score.weight, 0, 1)
627                if bias_score is not None:
628                    nn.init.normal_(bias_score.weight, 0, 1)
629            elif mode == "U01":
630                nn.init.uniform_(weight_score.weight, 0, 1)
631                if bias_score is not None:
632                    nn.init.uniform_(bias_score.weight, 0, 1)

Initialize the parameter score for the current task.

Args:

  • mode (str): The initialization mode for parameter scores; one of:
    1. 'default': the default initialization mode in the original WSN code.
    2. 'N01': standard normal distribution $N(0, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
def get_mask(self, stage: str, mask_percentage: float, test_mask: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | None = None) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
638    def get_mask(
639        self,
640        stage: str,
641        mask_percentage: float,
642        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
643    ) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
644        r"""Get the binary parameter mask used in the `forward()` method for different stages.
645
646        **Args:**
647        - **stage** (`str`): The stage when applying the conversion; one of:
648            1. 'train': training stage. Get the mask from the parameter score of the current task through the gate function that masks the top $c\%$ largest scored parameters. See Sec. 3.1 "Winning Subnetworks" in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
649            2. 'validation': validation stage. Same as 'train'. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
650            3. 'test': testing stage. Apply the test mask directly from the argument `test_mask`.
651        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
652
653        **Returns:**
654        - **weight_mask** (`dict[str, Tensor]`): The binary mask on weights. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
655        - **bias_mask** (`dict[str, Tensor]`): The binary mask on biases. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
656        """
657        weight_mask = {}
658        bias_mask = {}
659        if stage == "train" or stage == "validation":
660            for layer_name in self.weighted_layer_names:
661                weight_mask[layer_name] = self.gate_fn.apply(
662                    self.weight_score_t[layer_name].weight, mask_percentage
663                )
664                if self.bias_score_t[layer_name] is not None:
665                    bias_mask[layer_name] = self.gate_fn.apply(
666                        self.bias_score_t[layer_name].weight.squeeze(
667                            0
668                        ),  # from (1, output_dim) to (output_dim, )
669                        mask_percentage,
670                    )
671                else:
672                    bias_mask[layer_name] = None
673        elif stage == "test":
674            weight_mask, bias_mask = test_mask
675
676        return weight_mask, bias_mask

Get the binary parameter mask used in the forward() method for different stages.

Args:

  • stage (str): The stage when applying the conversion; one of:
    1. 'train': training stage. Get the mask from the parameter score of the current task through the gate function that masks the top $c\%$ largest scored parameters. See Sec. 3.1 "Winning Subnetworks" in the WSN paper.
    2. 'validation': validation stage. Same as 'train'. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
    3. 'test': testing stage. Apply the test mask directly from the argument test_mask.
  • mask_percentage (float): The percentage of parameters to be masked. The value should be between 0 and 1.
  • test_mask (tuple[dict[str, Tensor], dict[str, Tensor]] | None): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is None.

Returns:

  • weight_mask (dict[str, Tensor]): The binary mask on weights. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
  • bias_mask (dict[str, Tensor]): The binary mask on biases. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is None.
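As a rough sketch of how such a gate can be realized, the following custom autograd function thresholds the scores at a percentile in the forward pass and uses a straight-through estimator in the backward pass. This is an illustrative assumption, not the library's actual `gate_fn`, whose percentile convention and backward pass may differ:

```python
import torch

class PercentileGate(torch.autograd.Function):
    """Binary mask from real-valued scores: keep the top `percentage` fraction."""

    @staticmethod
    def forward(ctx, score: torch.Tensor, percentage: float) -> torch.Tensor:
        # Threshold at the (1 - percentage) quantile so the top `percentage`
        # fraction of scores receive mask value 1.
        threshold = torch.quantile(score.flatten().float(), 1.0 - percentage)
        return (score >= threshold).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Straight-through estimator: gradients flow to the scores unchanged;
        # no gradient for the percentage argument.
        return grad_output, None

score = torch.arange(10.0).reshape(2, 5)
mask = PercentileGate.apply(score, 0.5)
```

Because the hard thresholding has zero gradient almost everywhere, the straight-through backward pass is what lets the real-valued scores keep receiving learning signal during training.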
@override
def forward(self, input: torch.Tensor, stage: str, mask_percentage: float, test_mask: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, torch.Tensor]]:
678    @override
679    def forward(
680        self,
681        input: Tensor,
682        stage: str,
683        mask_percentage: float,
684        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
685    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
686        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the parameters in each layer.
687
688        **Args:**
689        - **input** (`Tensor`): The input tensor from data.
690        - **stage** (`str`): The stage of the forward pass; one of:
691            1. 'train': training stage.
692            2. 'validation': validation stage.
693            3. 'test': testing stage.
694        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
695        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias mask used for test. Applies only to the testing stage. For other stages, it is `None`.
696
697        **Returns:**
698        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
699        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
700        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
701        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used by continual learning algorithms that need the hidden features for various purposes.
702        """

The forward pass for data from task self.task_id. Task-specific masks for self.task_id are applied to the parameters in each layer.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • mask_percentage (float): The percentage of parameters to be masked. The value should be between 0 and 1.
  • test_mask (tuple[dict[str, Tensor], dict[str, Tensor]] | None): The binary weight and bias mask used for test. Applies only to the testing stage. For other stages, it is None.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • weight_mask (dict[str, Tensor]): The weight mask for the current task. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
  • bias_mask (dict[str, Tensor]): The bias mask for the current task. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is None.
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. This is used by continual learning algorithms that need the hidden features for various purposes.
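To illustrate how the returned masks are typically consumed, here is a minimal sketch (not the library implementation) of applying a binary weight and bias mask to one linear layer inside a WSN-style forward pass:

```python
import torch
import torch.nn.functional as F

def masked_linear(x, weight, bias, weight_mask, bias_mask):
    # Element-wise masking selects the winning subnetwork's parameters.
    w = weight * weight_mask
    b = bias * bias_mask if bias is not None and bias_mask is not None else bias
    return F.linear(x, w, b)

x = torch.randn(4, 8)
weight = torch.randn(16, 8)     # (output features, input features)
bias = torch.randn(16)          # (output features, )
weight_mask = torch.ones(16, 8)
weight_mask[8:] = 0.0           # zero out the weights of half the output units
bias_mask = torch.ones(16)

out = masked_linear(x, weight, bias, weight_mask, bias_mask)
```

With the masked weight rows zeroed, those output units reduce to their (unmasked) bias values, which shows the mask acting on parameters rather than on the input.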