clarena.backbones

Backbone Networks

This submodule provides the backbone neural network architectures for all paradigms in CLArena.

Here are the base classes for backbone networks, which inherit from PyTorch nn.Module:

  • Backbone: the base class for all backbone networks. Multi-task and single-task learning can use this class directly.
  • CLBackbone: the base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks.
    • HATMaskBackbone: the base class for backbones used in the HAT (Hard Attention to the Task) CL algorithm (http://proceedings.mlr.press/v80/serra18a).
    • WSNMaskBackbone: the base class for backbones used in the WSN (Winning Subnetworks) CL algorithm (https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement backbone networks:

  • Configure Backbone Network: https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/backbone-network
  • Implement Custom Backbone Network: https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/backbone-network

 1r"""
 2
 3# Backbone Networks
 4
 5This submodule provides the **backbone neural network architectures** for all paradigms in CLArena.
 6
 7Here are the base classes for backbone networks, which inherit from PyTorch `nn.Module`:
 8
 9- `Backbone`: the base class for all backbone networks. Multi-task and single-task learning can use this class directly.
10- `CLBackbone`: the base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks.
11    - `HATMaskBackbone`: the base class for backbones used in the [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) CL algorithm.
12    - `WSNMaskBackbone`: the base class for backbones used in the [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) CL algorithm.
13
14Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement backbone networks:
15
16- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/backbone-network)
17- [**Implement Custom Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/backbone-network)
18
19
20"""
21
22from .base import (
23    Backbone,
24    CLBackbone,
25    HATMaskBackbone,
26    WSNMaskBackbone,
27)
28from .mlp import MLP, CLMLP, HATMaskMLP, WSNMaskMLP
29from .resnet import (
30    ResNet18,
31    ResNet34,
32    ResNet50,
33    ResNet101,
34    ResNet152,
35    CLResNet18,
36    CLResNet34,
37    CLResNet50,
38    CLResNet101,
39    CLResNet152,
40    HATMaskResNet18,
41    HATMaskResNet34,
42    HATMaskResNet50,
43    HATMaskResNet101,
44    HATMaskResNet152,
45)
46
47
48__all__ = [
49    "Backbone",
50    "CLBackbone",
51    "HATMaskBackbone",
52    "WSNMaskBackbone",
53    "mlp",
54    "resnet",
55]
class Backbone(torch.nn.modules.module.Module):
 25class Backbone(nn.Module):
 26    r"""The base class for backbone networks."""
 27
 28    def __init__(self, output_dim: int | None, **kwargs) -> None:
 29        r"""
 30        **Args:**
 31        - **output_dim** (`int` | `None`): The output dimension that connects to output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
 32        - **kwargs**: Reserved for multiple inheritance.
 33        """
 34        super().__init__()
 35
 36        self.output_dim: int | None = output_dim
 37        r"""The output dimension of the backbone network."""
 38
 39        self.weighted_layer_names: list[str] = []
 40        r"""The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers. They are the main part of neural networks. **It must be provided in subclasses.**
 41        
 42        The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the `nn.Module` internal naming mechanism with `.` replaced with `/`. For example: 
 43        - If a layer is assigned to `self.conv1`, the name becomes `conv1`. 
 44        - If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc. 
 45        - If a hierarchical structure is used, for example, a `nn.Module` is assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it should have been `block.conv1` according to `nn.Module`'s rules, but we use '/' instead of '.' to avoid errors when using '.' as keys in a `ModuleDict`.
 46        """
 47
 48    def get_layer_by_name(self, layer_name: str | None) -> nn.Module | None:
 49        r"""Get the layer by its name.
 50
 51        **Args:**
 52        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.
 53
 54        **Returns:**
 55        - **layer** (`nn.Module` | `None`): The layer. If `layer_name` is `None` or no layer with this name is found, return `None`.
 56        """
 57        if layer_name is None:
 58            return None
 59
 60        for name, layer in self.named_modules():
 61            if name == layer_name.replace("/", "."):
 62                return layer
 63
 64    def preceding_layer_name(self, layer_name: str | None) -> str | None:
 65        r"""Get the name of the preceding layer of the given layer from the stored `weighted_layer_names`.
 66
 67        **Args:**
 68        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.
 69
 70        **Returns:**
 71        - **preceding_layer_name** (`str` | `None`): The name of the preceding layer. If the given layer is the first layer, return `None`.
 72        """
 73        if layer_name is None:
 74            return None
 75
 76        if layer_name not in self.weighted_layer_names:
 77            raise ValueError(
 78                f"The layer name {layer_name} doesn't exist in weighted layer names."
 79            )
 80
 81        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
 82        if weighted_layer_idx == 0:
 83            return None
 84        preceding_layer_name = self.weighted_layer_names[weighted_layer_idx - 1]
 85        return preceding_layer_name
 86
 87    def next_layer_name(self, layer_name: str) -> str | None:
 88        r"""Get the name of the next layer of the given layer from the stored `weighted_layer_names`. If the given layer is the last layer of the backbone, return `None`.
 89
 90        **Args:**
 91        - **layer_name** (`str`): The name of the layer.
 92
 93        **Returns:**
 94        - **next_layer_name** (`str` | `None`): The name of the next layer. If the given layer is the last layer, return `None`.
 95
 96        **Raises:**
 97        - **ValueError**: If `layer_name` is not in the weighted layer order.
 98        """
 99
100        if layer_name not in self.weighted_layer_names:
101            raise ValueError(f"The layer name {layer_name} doesn't exist.")
102
103        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
104        if weighted_layer_idx == len(self.weighted_layer_names) - 1:
105            return None
106        next_layer_name = self.weighted_layer_names[weighted_layer_idx + 1]
107        return next_layer_name
108
109    @override  # since `nn.Module` uses it
110    def forward(
111        self,
112        input: Tensor,
113        stage: str,
114    ) -> tuple[Tensor, dict[str, Tensor]]:
115        r"""The forward pass. **It must be implemented by subclasses.**
116
117        **Args:**
118        - **input** (`Tensor`): The input tensor from data.
119        - **stage** (`str`): The stage of the forward pass; one of:
120            1. 'train': training stage.
121            2. 'validation': validation stage.
122            3. 'test': testing stage.
123
124        **Returns:**
125        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
126        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for certain algorithms that need to use the hidden features for various purposes.
127        """

The base class for backbone networks.

Backbone(output_dim: int | None, **kwargs)
28    def __init__(self, output_dim: int | None, **kwargs) -> None:
29        r"""
30        **Args:**
31        - **output_dim** (`int` | `None`): The output dimension that connects to output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
32        - **kwargs**: Reserved for multiple inheritance.
33        """
34        super().__init__()
35
36        self.output_dim: int | None = output_dim
37        r"""The output dimension of the backbone network."""
38
39        self.weighted_layer_names: list[str] = []
40        r"""The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers. They are the main part of neural networks. **It must be provided in subclasses.**
41        
42        The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the `nn.Module` internal naming mechanism with `.` replaced with `/`. For example: 
43        - If a layer is assigned to `self.conv1`, the name becomes `conv1`. 
44        - If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc. 
45        - If a hierarchical structure is used, for example, a `nn.Module` is assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it should have been `block.conv1` according to `nn.Module`'s rules, but we use '/' instead of '.' to avoid errors when using '.' as keys in a `ModuleDict`.
46        """

Args:

  • output_dim (int | None): The output dimension that connects to output heads. The input_dim of output heads is expected to be the same as this output_dim. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be None.
  • kwargs: Reserved for multiple inheritance.
output_dim: int | None

The output dimension of the backbone network.

weighted_layer_names: list[str]

The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers. They are the main part of neural networks. It must be provided in subclasses.

The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the nn.Module internal naming mechanism with . replaced with /. For example:

  • If a layer is assigned to self.conv1, the name becomes conv1.
  • If nn.Sequential is used, the name becomes the index of the layer in the sequence, such as 0, 1, etc.
  • If a hierarchical structure is used, for example, a nn.Module is assigned to self.block which has self.conv1, the name becomes block/conv1. Note that it should have been block.conv1 according to nn.Module's rules, but we use '/' instead of '.' to avoid errors when using '.' as keys in a ModuleDict.
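To make the naming rules above concrete, here is a small sketch (the module layout and layer sizes are hypothetical) that prints the registered nn.Module names and shows how they map to weighted_layer_names entries:

```python
from torch import nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)  # registered as "block.conv1" inside Toy

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()                                              # "block"
        self.head = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 10))  # "head.0", "head.1"

print([name for name, _ in Toy().named_modules() if name])
# ['block', 'block.conv1', 'head', 'head.0', 'head.1']
# In `weighted_layer_names`, the weighted ones are written with '/' instead of '.':
# ["block/conv1", "head/0", "head/1"]
```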
def get_layer_by_name(self, layer_name: str | None) -> torch.nn.modules.module.Module | None:
48    def get_layer_by_name(self, layer_name: str | None) -> nn.Module | None:
49        r"""Get the layer by its name.
50
51        **Args:**
52        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.
53
54        **Returns:**
55        - **layer** (`nn.Module` | `None`): The layer. If `layer_name` is `None` or no layer with this name is found, return `None`.
56        """
57        if layer_name is None:
58            return None
59
60        for name, layer in self.named_modules():
61            if name == layer_name.replace("/", "."):
62                return layer

Get the layer by its name.

Args:

  • layer_name (str | None): The layer name following the nn.Module internal naming mechanism with . replaced with /. If None, return None.

Returns:

  • layer (nn.Module | None): The layer. If layer_name is None or no layer with this name is found, return None.
def preceding_layer_name(self, layer_name: str | None) -> str | None:
64    def preceding_layer_name(self, layer_name: str | None) -> str | None:
65        r"""Get the name of the preceding layer of the given layer from the stored `weighted_layer_names`.
66
67        **Args:**
68        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.
69
70        **Returns:**
71        - **preceding_layer_name** (`str` | `None`): The name of the preceding layer. If the given layer is the first layer, return `None`.
72        """
73        if layer_name is None:
74            return None
75
76        if layer_name not in self.weighted_layer_names:
77            raise ValueError(
78                f"The layer name {layer_name} doesn't exist in weighted layer names."
79            )
80
81        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
82        if weighted_layer_idx == 0:
83            return None
84        preceding_layer_name = self.weighted_layer_names[weighted_layer_idx - 1]
85        return preceding_layer_name

Get the name of the preceding layer of the given layer from the stored weighted_layer_names.

Args:

  • layer_name (str | None): The layer name following the nn.Module internal naming mechanism with . replaced with /. If None, return None.

Returns:

  • preceding_layer_name (str | None): The name of the preceding layer. If the given layer is the first layer, return None.
def next_layer_name(self, layer_name: str) -> str | None:
 87    def next_layer_name(self, layer_name: str) -> str | None:
 88        r"""Get the name of the next layer of the given layer from the stored `weighted_layer_names`. If the given layer is the last layer of the backbone, return `None`.
 89
 90        **Args:**
 91        - **layer_name** (`str`): The name of the layer.
 92
 93        **Returns:**
 94        - **next_layer_name** (`str` | `None`): The name of the next layer. If the given layer is the last layer, return `None`.
 95
 96        **Raises:**
 97        - **ValueError**: If `layer_name` is not in the weighted layer order.
 98        """
 99
100        if layer_name not in self.weighted_layer_names:
101            raise ValueError(f"The layer name {layer_name} doesn't exist.")
102
103        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
104        if weighted_layer_idx == len(self.weighted_layer_names) - 1:
105            return None
106        next_layer_name = self.weighted_layer_names[weighted_layer_idx + 1]
107        return next_layer_name

Get the name of the next layer of the given layer from the stored weighted_layer_names. If the given layer is the last layer of the backbone, return None.

Args:

  • layer_name (str): The name of the layer.

Returns:

  • next_layer_name (str | None): The name of the next layer. If the given layer is the last layer, return None.

Raises:

  • ValueError: If layer_name is not in the weighted layer order.
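The three helpers above operate only on the registered submodules and the weighted_layer_names list. A minimal sketch of their behavior (the layers and sizes are arbitrary; Backbone is instantiated directly here purely for illustration):

```python
from torch import nn
from clarena.backbones import Backbone

backbone = Backbone(output_dim=10)
backbone.conv1 = nn.Conv2d(3, 16, 3)       # registered as "conv1"
backbone.fc = nn.Linear(16, 10)            # registered as "fc"
backbone.weighted_layer_names = ["conv1", "fc"]

print(backbone.get_layer_by_name("conv1"))   # the Conv2d module
print(backbone.preceding_layer_name("fc"))   # 'conv1'
print(backbone.next_layer_name("fc"))        # None (last weighted layer)
```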
@override
def forward( self, input: torch.Tensor, stage: str) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
109    @override  # since `nn.Module` uses it
110    def forward(
111        self,
112        input: Tensor,
113        stage: str,
114    ) -> tuple[Tensor, dict[str, Tensor]]:
115        r"""The forward pass. **It must be implemented by subclasses.**
116
117        **Args:**
118        - **input** (`Tensor`): The input tensor from data.
119        - **stage** (`str`): The stage of the forward pass; one of:
120            1. 'train': training stage.
121            2. 'validation': validation stage.
122            3. 'test': testing stage.
123
124        **Returns:**
125        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
126        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for certain algorithms that need to use the hidden features for various purposes.
127        """

The forward pass. It must be implemented by subclasses.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used for certain algorithms that need to use the hidden features for various purposes.
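For reference, a minimal sketch (not part of the library) of a custom Backbone subclass that follows the forward() contract above: list the weighted layers in weighted_layer_names and return the output feature together with the per-layer activations keyed by those names. The layer sizes are assumptions for illustration only.

```python
from torch import Tensor, nn
from clarena.backbones import Backbone

class TwoLayerBackbone(Backbone):
    def __init__(self, input_dim: int = 784, hidden_dim: int = 256, output_dim: int = 64):
        super().__init__(output_dim=output_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU()
        self.weighted_layer_names = ["fc1", "fc2"]  # in order, from input to output

    def forward(self, input: Tensor, stage: str) -> tuple[Tensor, dict[str, Tensor]]:
        activations = {}
        feature = self.activation(self.fc1(input))
        activations["fc1"] = feature
        feature = self.activation(self.fc2(feature))
        activations["fc2"] = feature
        return feature, activations
```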
class CLBackbone(clarena.backbones.Backbone):
130class CLBackbone(Backbone):
131    r"""The base class of continual learning backbone networks."""
132
133    def __init__(self, output_dim: int | None, **kwargs) -> None:
134        r"""
135        **Args:**
136        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
137        - **kwargs**: Reserved for multiple inheritance.
138        """
139        super().__init__(output_dim=output_dim, **kwargs)
140
141        # task ID control
142        self.task_id: int
143        r"""Task ID counter indicating which task is being processed. Automatically updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
144        self.processed_task_ids: list[int] = []
145        r"""Task IDs that have been processed."""
146
147    def setup_task_id(self, task_id: int) -> None:
148        r"""Set up task `task_id`. This must be done before the `forward()` method is called."""
149        self.task_id = task_id
150        self.processed_task_ids.append(task_id)
151
152    @override  # since `nn.Module` uses it
153    def forward(
154        self,
155        input: Tensor,
156        stage: str,
157        task_id: int | None = None,
158    ) -> tuple[Tensor, dict[str, Tensor]]:
159        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**
160
161        **Args:**
162        - **input** (`Tensor`): The input tensor from data.
163        - **stage** (`str`): The stage of the forward pass; one of:
164            1. 'train': training stage.
165            2. 'validation': validation stage.
166            3. 'test': testing stage.
167        - **task_id** (`int` | `None`): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.
168
169        **Returns:**
170        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
171        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
172        """

The base class of continual learning backbone networks.

CLBackbone(output_dim: int | None, **kwargs)
133    def __init__(self, output_dim: int | None, **kwargs) -> None:
134        r"""
135        **Args:**
136        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
137        - **kwargs**: Reserved for multiple inheritance.
138        """
139        super().__init__(output_dim=output_dim, **kwargs)
140
141        # task ID control
142        self.task_id: int
143        r"""Task ID counter indicating which task is being processed. Automatically updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
144        self.processed_task_ids: list[int] = []
145        r"""Task IDs that have been processed."""

Args:

  • output_dim (int | None): The output dimension that connects to CL output heads. The input_dim of output heads is expected to be the same as this output_dim. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be None.
  • kwargs: Reserved for multiple inheritance.
task_id: int

Task ID counter indicating which task is being processed. Automatically updated during the task loop. Valid from 1 to the number of tasks in the CL dataset.

processed_task_ids: list[int]

Task IDs that have been processed.

def setup_task_id(self, task_id: int) -> None:
147    def setup_task_id(self, task_id: int) -> None:
148        r"""Set up task `task_id`. This must be done before the `forward()` method is called."""
149        self.task_id = task_id
150        self.processed_task_ids.append(task_id)

Set up task task_id. This must be done before the forward() method is called.

@override
def forward( self, input: torch.Tensor, stage: str, task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
152    @override  # since `nn.Module` uses it
153    def forward(
154        self,
155        input: Tensor,
156        stage: str,
157        task_id: int | None = None,
158    ) -> tuple[Tensor, dict[str, Tensor]]:
159        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**
160
161        **Args:**
162        - **input** (`Tensor`): The input tensor from data.
163        - **stage** (`str`): The stage of the forward pass; one of:
164            1. 'train': training stage.
165            2. 'validation': validation stage.
166            3. 'test': testing stage.
167        - **task_id** (`int` | `None`): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.
168
169        **Returns:**
170        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
171        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
172        """

The forward pass for data from task task_id. In some backbones, the forward pass might be different for different tasks. It must be implemented by subclasses.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • task_id (int | None): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task self.task_id. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
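A hedged usage sketch of the task-ID protocol described above: setup_task_id() must be called before forward() for each task. The toy subclass, layer sizes, and tensor shapes below are assumptions, not library code.

```python
import torch
from torch import Tensor, nn
from clarena.backbones import CLBackbone

class ToyCLBackbone(CLBackbone):
    def __init__(self, output_dim: int = 64):
        super().__init__(output_dim=output_dim)
        self.fc = nn.Linear(32, output_dim)
        self.weighted_layer_names = ["fc"]

    def forward(self, input: Tensor, stage: str, task_id: int | None = None):
        feature = torch.relu(self.fc(input))    # task-agnostic forward pass
        return feature, {"fc": feature}

backbone = ToyCLBackbone()
for t in (1, 2):
    backbone.setup_task_id(task_id=t)           # must precede forward() for task t
    feature, activations = backbone(torch.randn(8, 32), stage="train")
print(backbone.processed_task_ids)              # [1, 2]
```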
class HATMaskBackbone(clarena.backbones.CLBackbone):
175class HATMaskBackbone(CLBackbone):
176    r"""The backbone network for HAT-based algorithms with learnable hard attention masks.
177
178    HAT-based algorithms include:
179
180    - [**HAT (Hard Attention to the Task, 2018)**](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
181    - [**AdaHAT (Adaptive Hard Attention to the Task, 2024)**](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
182    - **FG-AdaHAT** is an architecture-based continual learning approach that improves HAT by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
183    """
184
185    def __init__(self, output_dim: int | None, gate: str, **kwargs) -> None:
186        r"""
187        **Args:**
188        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
189        - **gate** (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
190            - `sigmoid`: the sigmoid function.
191            - `tanh`: the hyperbolic tangent function. (Not currently supported: `sanity_check()` only accepts `sigmoid`.)
192        - **kwargs**: Reserved for multiple inheritance.
193        """
194        super().__init__(output_dim=output_dim, **kwargs)
195
196        self.gate: str = gate
197        r"""The type of gate function."""
198        self.gate_fn: Callable
199        r"""The gate function mapping the real value task embeddings into attention masks."""
200
201        if gate == "sigmoid":
202            self.gate_fn = nn.Sigmoid()
203
204        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
205        r"""The task embedding for the current task `self.task_id`. Keys are layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).
206        
207        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
208        
209        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
210        
211        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
212        """
213
214        self.masks: dict[int, dict[str, Tensor]] = {}
215        r"""The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ). """
216
217        HATMaskBackbone.sanity_check(self)
218
219    def initialize_task_embedding(self, mode: str) -> None:
220        r"""Initialize the task embedding for the current task `self.task_id`.
221
222        **Args:**
223        - **mode** (`str`): The initialization mode for task embeddings; one of:
224            1. 'N01' (default): standard normal distribution $N(0, 1)$.
225            2. 'U-11': uniform distribution $U(-1, 1)$.
226            3. 'U01': uniform distribution $U(0, 1)$.
227            4. 'U-10': uniform distribution $U(-1, 0)$.
228            5. 'last': inherit task embeddings from the last task.
229        """
230        for te in self.task_embedding_t.values():
231            if mode == "N01":
232                nn.init.normal_(te.weight, 0, 1)
233            elif mode == "U-11":
234                nn.init.uniform_(te.weight, -1, 1)
235            elif mode == "U01":
236                nn.init.uniform_(te.weight, 0, 1)
237            elif mode == "U-10":
238                nn.init.uniform_(te.weight, -1, 0)
239            elif mode == "last":
240                pass
241
242    def sanity_check(self) -> None:
243        r"""Sanity check."""
244
245        if self.gate not in ["sigmoid"]:
246            raise ValueError("The gate should be one of: 'sigmoid'.")
247
248    def get_mask(
249        self,
250        stage: str,
251        s_max: float | None = None,
252        batch_idx: int | None = None,
253        num_batches: int | None = None,
254        test_task_id: int | None = None,
255    ) -> dict[str, Tensor]:
256        r"""Get the hard attention mask used in the `forward()` method for different stages.
257
258        **Args:**
259        - **stage** (`str`): The stage when applying the conversion; one of:
260            1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
261            2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by `s_max`, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
262            3. 'test': testing stage. Apply the test mask directly from the stored masks using `test_task_id`.
263        - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
264        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
265        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
266        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
267
268        **Returns:**
269        - **mask** (`dict[str, Tensor]`): The hard attention (with values 0 or 1) mask. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
270        """
271
272        # sanity check
273        if stage == "train" and (
274            s_max is None or batch_idx is None or num_batches is None
275        ):
276            raise ValueError(
277                "The `s_max`, `batch_idx` and `num_batches` should be provided at the training stage, instead of the default value `None`."
278            )
279        if stage == "validation" and (s_max is None):
280            raise ValueError(
281                "The `s_max` should be provided at the validation stage, instead of the default value `None`."
282            )
283        if stage == "test" and (test_task_id is None):
284            raise ValueError(
285                "The `test_task_id` should be provided at the testing stage, instead of the default value `None`."
286            )
287
288        mask = {}
289        if stage == "train":
290            for layer_name in self.weighted_layer_names:
291                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
292                    num_batches - 1
293                )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
294                mask[layer_name] = self.gate_fn(
295                    self.task_embedding_t[layer_name].weight * anneal_scalar
296                ).squeeze()
297        elif stage == "validation":
298            for layer_name in self.weighted_layer_names:
299                mask[layer_name] = self.gate_fn(
300                    self.task_embedding_t[layer_name].weight * s_max
301                ).squeeze()
302        elif stage == "test":
303            mask = self.masks[test_task_id]
304
305        return mask
306
307    def te_to_binary_mask(self) -> dict[str, Tensor]:
308        r"""Convert the current task embedding into a binary mask.
309
310        This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.
311
312        **Returns:**
313        - **mask_t** (`dict[str, Tensor]`): The binary mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
314        """
315        # get the mask for the current task
316        mask_t = {
317            layer_name: (self.task_embedding_t[layer_name].weight > 0)
318            .float()
319            .squeeze()
320            .detach()
321            for layer_name in self.weighted_layer_names
322        }
323
324        return mask_t
325
326    def store_mask(self) -> dict[str, Tensor]:
327        r"""Store the binary mask for the current task `self.task_id` in `self.masks` and return it."""
328        mask_t = self.te_to_binary_mask()
329        self.masks[self.task_id] = mask_t
330
331        return mask_t
332
333    def get_layer_measure_parameter_wise(
334        self,
335        neuron_wise_measure: dict[str, Tensor],
336        layer_name: str,
337        aggregation_mode: str,
338    ) -> tuple[Tensor, Tensor]:
339        r"""Get the parameter-wise measure on the parameters right before the given layer.
340
341        It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.
342
343        Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.
344
345        This method is used to calculate parameter-wise measures in various HAT-based algorithms:
346
347        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks `cumulative_mask_for_previous_tasks`, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
348        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks `summative_mask_for_previous_tasks`, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
349        - **FG-AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks `summative_importance_for_previous_tasks`, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.
350
351        **Args:**
352        - **neuron_wise_measure** (`dict[str, Tensor]`): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
353        - **layer_name** (`str`): The name of the given layer.
354        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
355            - 'min': takes the minimum of the two connected unit measures.
356            - 'max': takes the maximum of the two connected unit measures.
357            - 'mean': takes the mean of the two connected unit measures.
358
359        **Returns:**
360        - **weight_measure** (`Tensor`): The weight measure matrix, the same size as the corresponding weights.
361        - **bias_measure** (`Tensor`): The bias measure vector, the same size as the corresponding bias.
362        """
363
364        # initialize the aggregation function
365        if aggregation_mode == "min":
366            aggregation_func = torch.min
367        elif aggregation_mode == "max":
368            aggregation_func = torch.max
369        elif aggregation_mode == "mean":
370            aggregation_func = torch.mean
371        else:
372            raise ValueError(
373                f"The aggregation method {aggregation_mode} is not supported."
374            )
375
376        # get the preceding layer
377        preceding_layer_name = self.preceding_layer_name(layer_name)
378
379        # get weight size for expanding the measures
380        layer = self.get_layer_by_name(layer_name)
381        weight_size = layer.weight.size()
382
383        # construct the weight-wise measure
384        layer_measure = neuron_wise_measure[layer_name]
385        layer_measure_broadcast_size = (-1, 1) + tuple(
386            1 for _ in range(len(weight_size) - 2)
387        )  # since the size of mask tensor is (number of units, ), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
388
389        layer_measure_broadcasted = layer_measure.view(
390            *layer_measure_broadcast_size
391        ).expand(
392            weight_size,
393        )  # expand the given layer mask to the weight size and broadcast
394
395        if (
396            preceding_layer_name
397        ):  # if the layer is not the first layer, where the preceding layer exists
398
399            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
400                1 for _ in range(len(weight_size) - 2)
401            )  # since the size of mask tensor is (number of units, ), we extend it to (1, number of units) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
402            preceding_layer_measure = neuron_wise_measure[preceding_layer_name]
403            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
404                *preceding_layer_measure_broadcast_size
405            ).expand(
406                weight_size
407            )  # expand the preceding layer mask to the weight size and broadcast
408            weight_measure = aggregation_func(
409                layer_measure_broadcasted, preceding_layer_measure_broadcasted
410            )  # get the minimum of the two mask vectors, from expanded
411        else:  # if the layer is the first layer
412            weight_measure = layer_measure_broadcasted
413
414        # construct the bias-wise measure
415        bias_measure = layer_measure
416
417        return weight_measure, bias_measure
418
419    @override
420    def forward(
421        self,
422        input: Tensor,
423        stage: str,
424        s_max: float | None = None,
425        batch_idx: int | None = None,
426        num_batches: int | None = None,
427        test_task_id: int | None = None,
428    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
429        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the units in each layer.
430
431        **Args:**
432        - **input** (`Tensor`): The input tensor from data.
433        - **stage** (`str`): The stage of the forward pass; one of:
434            1. 'train': training stage.
435            2. 'validation': validation stage.
436            3. 'test': testing stage.
437        - **s_max** (`float`): The maximum scaling factor in the gate function. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
438        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
439        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
440        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
441
442        **Returns:**
443        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
444        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
445        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency with other HAT-based algorithms whose backbones inherit this `forward()` method.
446        """

The backbone network for HAT-based algorithms with learnable hard attention masks.

HAT-based algorithms include:

  • HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
  • AdaHAT (Adaptive Hard Attention to the Task, 2024) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
  • FG-AdaHAT is an architecture-based continual learning approach that improves HAT by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
HATMaskBackbone(output_dim: int | None, gate: str, **kwargs)
185    def __init__(self, output_dim: int | None, gate: str, **kwargs) -> None:
186        r"""
187        **Args:**
188        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
189        - **gate** (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
190            - `sigmoid`: the sigmoid function.
191            - `tanh`: the hyperbolic tangent function. (Not currently supported: `sanity_check()` only accepts `sigmoid`.)
192        - **kwargs**: Reserved for multiple inheritance.
193        """
194        super().__init__(output_dim=output_dim, **kwargs)
195
196        self.gate: str = gate
197        r"""The type of gate function."""
198        self.gate_fn: Callable
199        r"""The gate function mapping the real value task embeddings into attention masks."""
200
201        if gate == "sigmoid":
202            self.gate_fn = nn.Sigmoid()
203
204        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
205        r"""The task embedding for the current task `self.task_id`. Keys are layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).
206        
207        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
208        
209        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
210        
211        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
212        """
213
214        self.masks: dict[int, dict[str, Tensor]] = {}
215        r"""The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ). """
216
217        HATMaskBackbone.sanity_check(self)

Args:

  • output_dim (int | None): The output dimension that connects to CL output heads. The input_dim of output heads is expected to be the same as this output_dim. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be None.
  • gate (str): The type of gate function turning the real value task embeddings into attention masks; one of:
    • sigmoid: the sigmoid function.
    • tanh: the hyperbolic tangent function. (Not currently supported: sanity_check() only accepts sigmoid.)
  • kwargs: Reserved for multiple inheritance.
gate: str

The type of gate function.

gate_fn: Callable

The gate function mapping the real value task embeddings into attention masks.

task_embedding_t: torch.nn.modules.container.ModuleDict

The task embedding for the current task self.task_id. Keys are layer names and values are the task embedding nn.Embedding for the layer. Each task embedding has size (1, number of units).

We use ModuleDict rather than dict to ensure LightningModule properly registers these model parameters for purposes such as automatic device transfer and model summaries.

We use nn.Embedding rather than nn.Parameter to store the task embedding for each layer, which is a type of nn.Module and can be accepted by nn.ModuleDict. (nn.Parameter cannot be accepted by nn.ModuleDict.)

This must be defined to cover each weighted layer (as listed in weighted_layer_names) in the backbone network. Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.

masks: dict[int, dict[str, torch.Tensor]]

The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ).
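Putting the two attributes above together, here is a minimal sketch (layer sizes assumed, forward() omitted) of an HATMaskBackbone subclass whose task embeddings cover every weighted layer, as required:

```python
from torch import nn
from clarena.backbones import HATMaskBackbone

class ToyHATBackbone(HATMaskBackbone):
    def __init__(self, output_dim: int = 64, gate: str = "sigmoid"):
        super().__init__(output_dim=output_dim, gate=gate)
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, output_dim)
        self.weighted_layer_names = ["fc1", "fc2"]
        # one task embedding of size (1, number of units) per weighted layer
        self.task_embedding_t["fc1"] = nn.Embedding(1, 256)
        self.task_embedding_t["fc2"] = nn.Embedding(1, output_dim)
        # a real subclass must also implement forward()
```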

def initialize_task_embedding(self, mode: str) -> None:
219    def initialize_task_embedding(self, mode: str) -> None:
220        r"""Initialize the task embedding for the current task `self.task_id`.
221
222        **Args:**
223        - **mode** (`str`): The initialization mode for task embeddings; one of:
224            1. 'N01' (default): standard normal distribution $N(0, 1)$.
225            2. 'U-11': uniform distribution $U(-1, 1)$.
226            3. 'U01': uniform distribution $U(0, 1)$.
227            4. 'U-10': uniform distribution $U(-1, 0)$.
228            5. 'last': inherit task embeddings from the last task.
229        """
230        for te in self.task_embedding_t.values():
231            if mode == "N01":
232                nn.init.normal_(te.weight, 0, 1)
233            elif mode == "U-11":
234                nn.init.uniform_(te.weight, -1, 1)
235            elif mode == "U01":
236                nn.init.uniform_(te.weight, 0, 1)
237            elif mode == "U-10":
238                nn.init.uniform_(te.weight, -1, 0)
239            elif mode == "last":
240                pass

Initialize the task embedding for the current task self.task_id.

Args:

  • mode (str): The initialization mode for task embeddings; one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit task embeddings from the last task.
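Hypothetical usage at the start of a new task, continuing the ToyHATBackbone sketch above (the mode string is one of the options listed):

```python
backbone = ToyHATBackbone()
backbone.setup_task_id(task_id=1)               # register task 1 as the current task
backbone.initialize_task_embedding(mode="N01")  # re-initialize the embeddings from N(0, 1)
```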
def sanity_check(self) -> None:
242    def sanity_check(self) -> None:
243        r"""Sanity check."""
244
245        if self.gate not in ["sigmoid"]:
246            raise ValueError("The gate should be one of: 'sigmoid'.")

Sanity check.

def get_mask( self, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_task_id: int | None = None) -> dict[str, torch.Tensor]:
248    def get_mask(
249        self,
250        stage: str,
251        s_max: float | None = None,
252        batch_idx: int | None = None,
253        num_batches: int | None = None,
254        test_task_id: int | None = None,
255    ) -> dict[str, Tensor]:
256        r"""Get the hard attention mask used in the `forward()` method for different stages.
257
258        **Args:**
259        - **stage** (`str`): The stage when applying the conversion; one of:
260            1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
261            2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by `s_max`, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
262            3. 'test': testing stage. Apply the test mask directly from the stored masks using `test_task_id`.
263        - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
264        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
265        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
266        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
267
268        **Returns:**
269        - **mask** (`dict[str, Tensor]`): The hard attention (with values 0 or 1) mask. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
270        """
271
272        # sanity check
273        if stage == "train" and (
274            s_max is None or batch_idx is None or num_batches is None
275        ):
276            raise ValueError(
277                "The `s_max`, `batch_idx` and `num_batches` should be provided at the training stage, instead of the default value `None`."
278            )
279        if stage == "validation" and (s_max is None):
280            raise ValueError(
281                "The `s_max` should be provided at the validation stage, instead of the default value `None`."
282            )
283        if stage == "test" and (test_task_id is None):
284            raise ValueError(
285                "The `test_task_id` should be provided at the testing stage, instead of the default value `None`."
286            )
287
288        mask = {}
289        if stage == "train":
290            for layer_name in self.weighted_layer_names:
291                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
292                    num_batches - 1
293                )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
294                mask[layer_name] = self.gate_fn(
295                    self.task_embedding_t[layer_name].weight * anneal_scalar
296                ).squeeze()
297        elif stage == "validation":
298            for layer_name in self.weighted_layer_names:
299                mask[layer_name] = self.gate_fn(
300                    self.task_embedding_t[layer_name].weight * s_max
301                ).squeeze()
302        elif stage == "test":
303            mask = self.masks[test_task_id]
304
305        return mask

Get the hard attention mask used in the forward() method for different stages.

Args:

  • stage (str): The stage when applying the conversion; one of:
    1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the HAT paper.
    2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by s_max, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
    3. 'test': testing stage. Apply the test mask directly from the stored masks using test_task_id.
  • s_max (float): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the HAT paper.
  • batch_idx (int | None): The current batch index. Applies only to the training stage. For other stages, it is None.
  • num_batches (int | None): The total number of batches. Applies only to the training stage. For other stages, it is None.
  • test_task_id (int | None): The test task ID. Applies only to the testing stage. For other stages, it is None.

Returns:

  • mask (dict[str, Tensor]): The hard attention (with values 0 or 1) mask. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has size (number of units, ).
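As a worked example of the 'train' branch above (Eq. (3) in the HAT paper), the annealing scalar sweeps from 1/s_max on the first batch to s_max on the last, so the gated mask moves from nearly uniform to nearly binary within each epoch. The s_max and batch count below are assumed values:

```python
s_max, num_batches = 400.0, 100
for batch_idx in (1, 50, 100):
    anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)
    print(batch_idx, round(anneal_scalar, 4))
# 1 -> 0.0025 (soft, nearly uniform mask), 50 -> ~197.98, 100 -> 400.0 (nearly binary mask)
```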
def te_to_binary_mask(self) -> dict[str, torch.Tensor]:
307    def te_to_binary_mask(self) -> dict[str, Tensor]:
308        r"""Convert the current task embedding into a binary mask.
309
310        This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.
311
312        **Returns:**
313        - **mask_t** (`dict[str, Tensor]`): The binary mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
314        """
315        # get the mask for the current task
316        mask_t = {
317            layer_name: (self.task_embedding_t[layer_name].weight > 0)
318            .float()
319            .squeeze()
320            .detach()
321            for layer_name in self.weighted_layer_names
322        }
323
324        return mask_t

Convert the current task embedding into a binary mask.

This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.

Returns:

  • mask_t (dict[str, Tensor]): The binary mask for the current task. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has size (number of units, ).
def store_mask(self) -> dict[str, torch.Tensor]:
326    def store_mask(self) -> dict[str, Tensor]:
327        r"""Store the binary mask for the current task `self.task_id` in `self.masks` and return it."""
328        mask_t = self.te_to_binary_mask()
329        self.masks[self.task_id] = mask_t
330
331        return mask_t

Store the binary mask for the current task self.task_id in self.masks and return it.
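A hypothetical end-of-task and test-time flow, continuing the sketches above:

```python
backbone.store_mask()                           # binarize task 1's embedding and store it in `masks`
test_mask = backbone.get_mask(stage="test", test_task_id=1)
print({name: mask.shape for name, mask in test_mask.items()})
# {'fc1': torch.Size([256]), 'fc2': torch.Size([64])}
```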

def get_layer_measure_parameter_wise( self, neuron_wise_measure: dict[str, torch.Tensor], layer_name: str, aggregation_mode: str) -> tuple[torch.Tensor, torch.Tensor]:
333    def get_layer_measure_parameter_wise(
334        self,
335        neuron_wise_measure: dict[str, Tensor],
336        layer_name: str,
337        aggregation_mode: str,
338    ) -> tuple[Tensor, Tensor]:
339        r"""Get the parameter-wise measure on the parameters right before the given layer.
340
341        It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.
342
343        Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.
344
345        This method is used to calculate parameter-wise measures in various HAT-based algorithms:
346
347        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks `cumulative_mask_for_previous_tasks`, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
348        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks `summative_mask_for_previous_tasks`, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
349        - **FG-AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks `summative_importance_for_previous_tasks`, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.
350
351        **Args:**
352        - **neuron_wise_measure** (`dict[str, Tensor]`): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
353        - **layer_name** (`str`): The name of the given layer.
354        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
355            - 'min': takes the minimum of the two connected unit measures.
356            - 'max': takes the maximum of the two connected unit measures.
357            - 'mean': takes the mean of the two connected unit measures.
358
359        **Returns:**
360        - **weight_measure** (`Tensor`): The weight measure matrix, the same size as the corresponding weights.
361        - **bias_measure** (`Tensor`): The bias measure vector, the same size as the corresponding bias.
362        """
363
364        # initialize the aggregation function
365        if aggregation_mode == "min":
366            aggregation_func = torch.min
367        elif aggregation_mode == "max":
368            aggregation_func = torch.max
369        elif aggregation_mode == "mean":
370            aggregation_func = torch.mean
371        else:
372            raise ValueError(
373                f"The aggregation method {aggregation_mode} is not supported."
374            )
375
376        # get the preceding layer
377        preceding_layer_name = self.preceding_layer_name(layer_name)
378
379        # get weight size for expanding the measures
380        layer = self.get_layer_by_name(layer_name)
381        weight_size = layer.weight.size()
382
383        # construct the weight-wise measure
384        layer_measure = neuron_wise_measure[layer_name]
385        layer_measure_broadcast_size = (-1, 1) + tuple(
386            1 for _ in range(len(weight_size) - 2)
387        )  # since the size of mask tensor is (number of units, ), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
388
389        layer_measure_broadcasted = layer_measure.view(
390            *layer_measure_broadcast_size
391        ).expand(
392            weight_size,
393        )  # expand the given layer mask to the weight size and broadcast
394
395        if (
396            preceding_layer_name
397        ):  # if the layer is not the first layer, where the preceding layer exists
398
399            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
400                1 for _ in range(len(weight_size) - 2)
401            )  # since the size of mask tensor is (number of units, ), we extend it to (1, number of units) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
402            preceding_layer_measure = neuron_wise_measure[preceding_layer_name]
403            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
404                *preceding_layer_measure_broadcast_size
405            ).expand(
406                weight_size
407            )  # expand the preceding layer mask to the weight size and broadcast
408            weight_measure = aggregation_func(
409                layer_measure_broadcasted, preceding_layer_measure_broadcasted
410            )  # aggregate the two broadcasted measures element-wise (min/max/mean)
411        else:  # if the layer is the first layer
412            weight_measure = layer_measure_broadcasted
413
414        # construct the bias-wise measure
415        bias_measure = layer_measure
416
417        return weight_measure, bias_measure

Get the parameter-wise measure on the parameters right before the given layer.

It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$. A worked broadcasting example is given after the return list below.

Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.

This method is used to calculate parameter-wise measures in various HAT-based algorithms:

  • HAT: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks cumulative_mask_for_previous_tasks, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the HAT paper.
  • AdaHAT: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks summative_mask_for_previous_tasks, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the AdaHAT paper.
  • FG-AdaHAT: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks summative_importance_for_previous_tasks, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.

Args:

  • neuron_wise_measure (dict[str, Tensor]): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
  • layer_name (str): The name of the given layer.
  • aggregation_mode (str): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
    • 'min': takes the minimum of the two connected unit measures.
    • 'max': takes the maximum of the two connected unit measures.
    • 'mean': takes the mean of the two connected unit measures.

Returns:

  • weight_measure (Tensor): The weight measure matrix, the same size as the corresponding weights.
  • bias_measure (Tensor): The bias measure vector, the same size as the corresponding bias.
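Continuing the example in the description above, here is a self-contained illustration of the 'min' aggregation for a hypothetical fully connected layer with 3 output units whose preceding layer has 4 units:

```python
import torch

# neuron-wise measures: one value per unit in the given layer and in the preceding layer
m_layer = torch.tensor([0.2, 0.9, 0.5])           # given layer, 3 units
m_preceding = torch.tensor([0.1, 0.8, 0.3, 0.6])  # preceding layer, 4 units

# broadcast both to the weight shape (3, 4) and aggregate element-wise with 'min'
weight_measure = torch.min(
    m_layer.view(-1, 1).expand(3, 4),
    m_preceding.view(1, -1).expand(3, 4),
)
bias_measure = m_layer  # the bias measure is simply the given layer's measure

# entry (i, j) of weight_measure equals min(m_layer[i], m_preceding[j])
```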
@override
def forward( self, input: torch.Tensor, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor]]:
419    @override
420    def forward(
421        self,
422        input: Tensor,
423        stage: str,
424        s_max: float | None = None,
425        batch_idx: int | None = None,
426        num_batches: int | None = None,
427        test_task_id: int | None = None,
428    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
429        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the units in each layer.
430
431        **Args:**
432        - **input** (`Tensor`): The input tensor from data.
433        - **stage** (`str`): The stage of the forward pass; one of:
434            1. 'train': training stage.
435            2. 'validation': validation stage.
436            3. 'test': testing stage.
437        - **s_max** (`float` | `None`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
438        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
439        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
440        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
441
442        **Returns:**
443        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
444        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
445        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
446        """

The forward pass for data from task self.task_id. Task-specific masks for self.task_id are applied to the units in each layer.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • s_max (float | None): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the HAT paper.
  • batch_idx (int | None): The current batch index. Applies only to the training stage. For other stages, it is None.
  • num_batches (int | None): The total number of batches. Applies only to the training stage. For other stages, it is None.
  • test_task_id (int | None): The test task ID. Applies only to the testing stage. For other stages, it is None.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • mask (dict[str, Tensor]): The mask for the current task. Keys (str) are layer names and values (Tensor) are the mask tensors. The mask tensor has size (number of units, ).
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this forward() method of the HAT class.
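To make "masks applied to the units in each layer" concrete, here is a minimal, hedged sketch of the gating mechanism for a single fully connected layer; the layer and mask below are illustrative, not the library's `forward()` implementation:

```python
import torch
from torch import nn

fc = nn.Linear(32, 16)
x = torch.randn(4, 32)

# hypothetical unit-wise mask: one value per output unit of the layer
mask = (torch.randn(16) > 0).float()

hidden = torch.relu(fc(x))  # activation of the weighted layer
gated = hidden * mask       # masked-out units are zeroed before feeding the next layer
```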
class WSNMaskBackbone(clarena.backbones.CLBackbone):
449class WSNMaskBackbone(CLBackbone):
450    r"""The backbone network for the WSN algorithm with learnable parameter masks.
451
452    [WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) is an architecture-based continual learning algorithm. It trains learnable parameter-wise scores and selects the top $c\%$ of the network parameters by score to be used for each task.
453    """
454
455    def __init__(self, output_dim: int | None, **kwargs) -> None:
456        r"""
457        **Args:**
458        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
459        - **kwargs**: Reserved for multiple inheritance.
460        """
461        super().__init__(output_dim=output_dim, **kwargs)
462
463        self.gate_fn: torch.autograd.Function = PercentileLayerParameterMaskingByScore
464        r"""The gate function mapping the real-value parameter score into binary parameter masks. It is a custom autograd function that applies percentile parameter masking by score."""
465
466        self.weight_score_t: nn.ModuleDict = nn.ModuleDict()
467        r"""The weight score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (output features, input features) as the weight.
468        
469        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
470        
471        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
472        
473        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
474        """
475
476        self.bias_score_t: nn.ModuleDict = nn.ModuleDict()
477        r"""The bias score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (1, output features) as the bias. If the layer doesn't have a bias, it is `None`.
478        
479        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
480        
481        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
482        
483        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
484        """
485
486        WSNMaskBackbone.sanity_check(self)
487
488    def sanity_check(self) -> None:
489        r"""Sanity check."""
490
491    def initialize_parameter_score(self, mode: str) -> None:
492        r"""Initialize the parameter score for the current task.
493
494        **Args:**
495        - **mode** (`str`): The initialization mode for parameter scores; one of:
496            1. 'default': the default initialization mode in the original WSN code.
497            2. 'N01': standard normal distribution $N(0, 1)$.
498            3. 'U01': uniform distribution $U(0, 1)$.
499        """
500
501        for layer_name, weight_score in self.weight_score_t.items():
502            if mode == "default":
503                # Kaiming Uniform Initialization for weight score
504                nn.init.kaiming_uniform_(weight_score.weight, a=math.sqrt(5))
505
506                for layer_name, bias_score in self.bias_score_t.items():
507                    if bias_score is not None:
508                        # For bias, follow the standard bias initialization using fan_in
509                        weight_score = self.weight_score_t[layer_name]
510                        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
511                            weight_score.weight
512                        )
513                        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
514                        nn.init.uniform_(bias_score.weight, -bound, bound)
515            elif mode == "N01":
516                nn.init.normal_(weight_score.weight, 0, 1)
517                for layer_name, bias_score in self.bias_score_t.items():
518                    if bias_score is not None:
519                        nn.init.normal_(bias_score.weight, 0, 1)
520            elif mode == "U01":
521                nn.init.uniform_(weight_score.weight, 0, 1)
522                for layer_name, bias_score in self.bias_score_t.items():
523                    if bias_score is not None:
524                        nn.init.uniform_(bias_score.weight, 0, 1)
525
526    def get_mask(
527        self,
528        stage: str,
529        mask_percentage: float,
530        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
531    ) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
532        r"""Get the binary parameter mask used in the `forward()` method for different stages.
533
534        **Args:**
535        - **stage** (`str`): The stage when applying the conversion; one of:
536            1. 'train': training stage. Get the mask from the parameter score of the current task through the gate function that masks the top $c\%$ largest scored parameters. See Sec. 3.1 "Winning Subnetworks" in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
537            2. 'validation': validation stage. Same as 'train'. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
538            3. 'test': testing stage. Apply the test mask directly from the argument `test_mask`.
539        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
540
541        **Returns:**
542        - **weight_mask** (`dict[str, Tensor]`): The binary mask on weights. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
543        - **bias_mask** (`dict[str, Tensor]`): The binary mask on biases. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
544        """
545        weight_mask = {}
546        bias_mask = {}
547        if stage == "train" or stage == "validation":
548            for layer_name in self.weighted_layer_names:
549                weight_mask[layer_name] = self.gate_fn.apply(
550                    self.weight_score_t[layer_name].weight, mask_percentage
551                )
552                if self.bias_score_t[layer_name] is not None:
553                    bias_mask[layer_name] = self.gate_fn.apply(
554                        self.bias_score_t[layer_name].weight.squeeze(
555                            0
556                        ),  # from (1, output_dim) to (output_dim, )
557                        mask_percentage,
558                    )
559                else:
560                    bias_mask[layer_name] = None
561        elif stage == "test":
562            weight_mask, bias_mask = test_mask
563
564        return weight_mask, bias_mask
565
566    @override
567    def forward(
568        self,
569        input: Tensor,
570        stage: str,
571        mask_percentage: float,
572        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
573    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
574        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the parameters in each layer.
575
576        **Args:**
577        - **input** (`Tensor`): The input tensor from data.
578        - **stage** (`str`): The stage of the forward pass; one of:
579            1. 'train': training stage.
580            2. 'validation': validation stage.
581            3. 'test': testing stage.
582        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
583        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
584
585        **Returns:**
586        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
587        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
588        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
589        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need the hidden features for various purposes.
590        """

The backbone network for the WSN algorithm with learnable parameter masks.

WSN (Winning Subnetworks, 2022) is an architecture-based continual learning algorithm. It trains learnable parameter-wise scores and selects the top $c\%$ of the network parameters by score to be used for each task.

WSNMaskBackbone(output_dim: int | None, **kwargs)
455    def __init__(self, output_dim: int | None, **kwargs) -> None:
456        r"""
457        **Args:**
458        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
459        - **kwargs**: Reserved for multiple inheritance.
460        """
461        super().__init__(output_dim=output_dim, **kwargs)
462
463        self.gate_fn: torch.autograd.Function = PercentileLayerParameterMaskingByScore
464        r"""The gate function mapping the real-value parameter score into binary parameter masks. It is a custom autograd function that applies percentile parameter masking by score."""
465
466        self.weight_score_t: nn.ModuleDict = nn.ModuleDict()
467        r"""The weight score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (output features, input features) as the weight.
468        
469        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
470        
471        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
472        
473        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
474        """
475
476        self.bias_score_t: nn.ModuleDict = nn.ModuleDict()
477        r"""The bias score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (1, output features) as the bias. If the layer doesn't have a bias, it is `None`.
478        
479        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.
480        
481        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)
482        
483        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
484        """
485
486        WSNMaskBackbone.sanity_check(self)

Args:

  • output_dim (int | None): The output dimension that connects to CL output heads. The input_dim of output heads is expected to be the same as this output_dim. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be None.
  • kwargs: Reserved for multiple inheritance.
gate_fn: torch.autograd.function.Function

The gate function mapping the real-value parameter score into binary parameter masks. It is a custom autograd function that applies percentile parameter masking by score.

weight_score_t: torch.nn.modules.container.ModuleDict

The weight score for the current task self.task_id. Keys are the layer names and values are the task embedding nn.Embedding for the layer. Each task embedding has the same size (output features, input features) as the weight.

We use ModuleDict rather than dict to ensure LightningModule properly registers these model parameters for purposes such as automatic device transfer and model summaries.

We use nn.Embedding rather than nn.Parameter to store the task embedding for each layer, which is a type of nn.Module and can be accepted by nn.ModuleDict. (nn.Parameter cannot be accepted by nn.ModuleDict.)

This must be defined to cover each weighted layer (as listed in weighted_layer_names) in the backbone network. Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.

bias_score_t: torch.nn.modules.container.ModuleDict

The bias score for the current task self.task_id. Keys are the layer names and values are the task embedding nn.Embedding for the layer. Each task embedding has the same size (1, output features) as the bias. If the layer doesn't have a bias, it is None.

We use ModuleDict rather than dict to ensure LightningModule properly registers these model parameters for purposes such as automatic device transfer and model summaries.

We use nn.Embedding rather than nn.Parameter to store the task embedding for each layer, which is a type of nn.Module and can be accepted by nn.ModuleDict. (nn.Parameter cannot be accepted by nn.ModuleDict.)

This must be defined to cover each weighted layer (as listed in weighted_layer_names) in the backbone network. Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
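As a concrete illustration of the shapes described above, a subclass covering a single `nn.Linear` layer might register its scores roughly as follows; the layer name 'fc1' is a placeholder, not a name used by the library:

```python
from torch import nn

out_features, in_features = 16, 32
layer = nn.Linear(in_features, out_features)

weight_score_t = nn.ModuleDict()
bias_score_t = nn.ModuleDict()

# the weight score mirrors the weight shape (output features, input features)
weight_score_t["fc1"] = nn.Embedding(out_features, in_features)
# the bias score is stored as (1, output features); squeeze(0) recovers the bias shape
bias_score_t["fc1"] = nn.Embedding(1, out_features)

assert weight_score_t["fc1"].weight.shape == layer.weight.shape
assert bias_score_t["fc1"].weight.squeeze(0).shape == layer.bias.shape
```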

def sanity_check(self) -> None:
488    def sanity_check(self) -> None:
489        r"""Sanity check."""

Sanity check.

def initialize_parameter_score(self, mode: str) -> None:
491    def initialize_parameter_score(self, mode: str) -> None:
492        r"""Initialize the parameter score for the current task.
493
494        **Args:**
495        - **mode** (`str`): The initialization mode for parameter scores; one of:
496            1. 'default': the default initialization mode in the original WSN code.
497            2. 'N01': standard normal distribution $N(0, 1)$.
498            3. 'U01': uniform distribution $U(0, 1)$.
499        """
500
501        for layer_name, weight_score in self.weight_score_t.items():
502            if mode == "default":
503                # Kaiming Uniform Initialization for weight score
504                nn.init.kaiming_uniform_(weight_score.weight, a=math.sqrt(5))
505
506                for layer_name, bias_score in self.bias_score_t.items():
507                    if bias_score is not None:
508                        # For bias, follow the standard bias initialization using fan_in
509                        weight_score = self.weight_score_t[layer_name]
510                        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
511                            weight_score.weight
512                        )
513                        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
514                        nn.init.uniform_(bias_score.weight, -bound, bound)
515            elif mode == "N01":
516                nn.init.normal_(weight_score.weight, 0, 1)
517                for layer_name, bias_score in self.bias_score_t.items():
518                    if bias_score is not None:
519                        nn.init.normal_(bias_score.weight, 0, 1)
520            elif mode == "U01":
521                nn.init.uniform_(weight_score.weight, 0, 1)
522                for layer_name, bias_score in self.bias_score_t.items():
523                    if bias_score is not None:
524                        nn.init.uniform_(bias_score.weight, 0, 1)

Initialize the parameter score for the current task.

Args:

  • mode (str): The initialization mode for parameter scores; one of:
    1. 'default': the default initialization mode in the original WSN code.
    2. 'N01': standard normal distribution $N(0, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
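The three modes correspond to standard `torch.nn.init` calls, as in the source listing above. A self-contained sketch on a single hypothetical score tensor:

```python
import math
from torch import nn

score = nn.Embedding(16, 32)  # hypothetical weight score for one layer

# 'default': Kaiming-uniform initialization, as in the original WSN code
nn.init.kaiming_uniform_(score.weight, a=math.sqrt(5))

# 'N01': standard normal distribution N(0, 1)
nn.init.normal_(score.weight, 0, 1)

# 'U01': uniform distribution U(0, 1)
nn.init.uniform_(score.weight, 0, 1)
```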
def get_mask( self, stage: str, mask_percentage: float, test_mask: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | None = None) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
526    def get_mask(
527        self,
528        stage: str,
529        mask_percentage: float,
530        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
531    ) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
532        r"""Get the binary parameter mask used in the `forward()` method for different stages.
533
534        **Args:**
535        - **stage** (`str`): The stage when applying the conversion; one of:
536            1. 'train': training stage. Get the mask from the parameter score of the current task through the gate function that masks the top $c\%$ largest scored parameters. See Sec. 3.1 "Winning Subnetworks" in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
537            2. 'validation': validation stage. Same as 'train'. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
538            3. 'test': testing stage. Apply the test mask directly from the argument `test_mask`.
539        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
540
541        **Returns:**
542        - **weight_mask** (`dict[str, Tensor]`): The binary mask on weights. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
543        - **bias_mask** (`dict[str, Tensor]`): The binary mask on biases. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
544        """
545        weight_mask = {}
546        bias_mask = {}
547        if stage == "train" or stage == "validation":
548            for layer_name in self.weighted_layer_names:
549                weight_mask[layer_name] = self.gate_fn.apply(
550                    self.weight_score_t[layer_name].weight, mask_percentage
551                )
552                if self.bias_score_t[layer_name] is not None:
553                    bias_mask[layer_name] = self.gate_fn.apply(
554                        self.bias_score_t[layer_name].weight.squeeze(
555                            0
556                        ),  # from (1, output_dim) to (output_dim, )
557                        mask_percentage,
558                    )
559                else:
560                    bias_mask[layer_name] = None
561        elif stage == "test":
562            weight_mask, bias_mask = test_mask
563
564        return weight_mask, bias_mask

Get the binary parameter mask used in the forward() method for different stages.

Args:

  • stage (str): The stage when applying the conversion; one of:
    1. 'train': training stage. Get the mask from the parameter score of the current task through the gate function that masks the top $c\%$ largest scored parameters. See Sec. 3.1 "Winning Subnetworks" in the WSN paper. (A rough sketch of this selection is given after the return list below.)
    2. 'validation': validation stage. Same as 'train'. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
    3. 'test': testing stage. Apply the test mask directly from the argument test_mask.
  • mask_percentage (float): The percentage of parameters to be masked (i.e., selected with mask value 1). The value should be between 0 and 1.
  • test_mask (tuple[dict[str, Tensor], dict[str, Tensor]] | None): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is None.

Returns:

  • weight_mask (dict[str, Tensor]): The binary mask on weights. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
  • bias_mask (dict[str, Tensor]): The binary mask on biases. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is None.
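The gate itself is the custom autograd function PercentileLayerParameterMaskingByScore, whose source is not shown here. As a rough, forward-only sketch (the real autograd function also defines a backward pass so the scores stay trainable), the top-$c\%$ selection could look like this; topc_mask is a hypothetical helper, not a library function:

```python
import torch

def topc_mask(score: torch.Tensor, mask_percentage: float) -> torch.Tensor:
    """Keep the top `mask_percentage` fraction of entries (by score) as 1, the rest as 0."""
    k = int(mask_percentage * score.numel())
    if k == 0:
        return torch.zeros_like(score)
    threshold = score.flatten().topk(k).values.min()
    return (score >= threshold).float()

score = torch.randn(16, 32)          # hypothetical weight score of one layer
weight_mask = topc_mask(score, 0.5)  # roughly half of the parameters selected
```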
@override
def forward( self, input: torch.Tensor, stage: str, mask_percentage: float, test_mask: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, torch.Tensor]]:
566    @override
567    def forward(
568        self,
569        input: Tensor,
570        stage: str,
571        mask_percentage: float,
572        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
573    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
574        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the parameters in each layer.
575
576        **Args:**
577        - **input** (`Tensor`): The input tensor from data.
578        - **stage** (`str`): The stage of the forward pass; one of:
579            1. 'train': training stage.
580            2. 'validation': validation stage.
581            3. 'test': testing stage.
582        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
583        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
584
585        **Returns:**
586        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
587        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
588        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
589        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need the hidden features for various purposes.
590        """

The forward pass for data from task self.task_id. Task-specific masks for self.task_id are applied to the parameters in each layer.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • mask_percentage (float): The percentage of parameters to be masked. The value should be between 0 and 1.
  • test_mask (tuple[dict[str, Tensor], dict[str, Tensor]] | None): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is None.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • weight_mask (dict[str, Tensor]): The weight mask for the current task. Keys (str) are layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
  • bias_mask (dict[str, Tensor]): The bias mask for the current task. Keys (str) are layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is None.
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. This is used for continual learning algorithms that need the hidden features for various purposes.
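To show what the returned parameter masks do, here is a hedged sketch (not the library's `forward()`) of masking a single linear layer's weights and bias during the pass:

```python
import torch
import torch.nn.functional as F
from torch import nn

fc = nn.Linear(32, 16)
x = torch.randn(4, 32)

# hypothetical binary masks with the same shapes as the weight and the bias
weight_mask = (torch.rand_like(fc.weight) > 0.5).float()
bias_mask = (torch.rand_like(fc.bias) > 0.5).float()

# only the selected (masked-in) parameters take part in the computation
hidden = F.linear(x, fc.weight * weight_mask, fc.bias * bias_mask)
activation = torch.relu(hidden)
```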