clarena.backbones

Backbone Networks for Continual Learning

This submodule provides the backbone neural network architectures for continual learning.

Please note that this is API documentation. Please refer to the main documentation pages for more information about the backbone networks and how to configure and implement them:

  • Configure Backbone Network: https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/backbone-network
  • Implement Your CL Backbone Class: https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-CL-modules/backbone-network

The backbones are implemented as subclasses of CLBackbone, the base class for all continual learning backbones in CLArena.

 1r"""
 2
 3# Backbone Networks for Continual Learning
 4
 5This submodule provides the **backbone neural network architectures for continual learning**. 
 6
 7Please note that this is API documentation. Please refer to the main documentation pages for more information about the backbone networks and how to 
 8configure and implement them:
 9
10- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/backbone-network)
11- [**Implement Your CL Backbone Class**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-CL-modules/backbone-network)
12
13
14
 15The backbones are implemented as subclasses of the `CLBackbone` class, which is the base class for all continual learning backbones in CLArena.
16
17- `CLBackbone`: The base class for continual learning backbones.
18- `HATMaskBackbone`: The base class for backbones used in [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a). A child class of `CLBackbone`.
19
20
21"""
22
23from .base import CLBackbone, HATMaskBackbone
24from .mlp import MLP, HATMaskMLP
25from .resnet import (
26    HATMaskResNet18,
27    HATMaskResNet34,
28    HATMaskResNet50,
29    HATMaskResNet101,
30    HATMaskResNet152,
31    ResNet18,
32    ResNet34,
33    ResNet50,
34    ResNet101,
35    ResNet152,
36)
37
38__all__ = ["CLBackbone", "HATMaskBackbone", "mlp", "resnet"]
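For orientation, the relationships promised above can be checked directly from these exports. The following is a minimal sketch, not part of the CLArena source; it only assumes that the listed classes are importable from `clarena.backbones` and related as described in this page.

from clarena.backbones import (
    CLBackbone,
    HATMaskBackbone,
    MLP,
    HATMaskMLP,
    ResNet18,
)

# Every backbone is expected to be a subclass of CLBackbone; the HAT variants
# additionally inherit from HATMaskBackbone (itself a child of CLBackbone).
assert issubclass(HATMaskBackbone, CLBackbone)
assert issubclass(MLP, CLBackbone)
assert issubclass(ResNet18, CLBackbone)
assert issubclass(HATMaskMLP, HATMaskBackbone)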
class CLBackbone(torch.nn.modules.module.Module):
 18class CLBackbone(nn.Module):
 19    r"""The base class of continual learning backbone networks, inherited from `nn.Module`."""
 20
 21    def __init__(self, output_dim: int | None) -> None:
 22        r"""Initialise the CL backbone network.
 23
 24        **Args:**
 25        - **output_dim** (`int` | `None`): The output dimension which connects to the CL output heads. The `input_dim` of the output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block within a backbone network and has no output dimension; in that case, it can be `None`.
 26        """
 27        nn.Module.__init__(self)
 28
 29        self.output_dim: int = output_dim
 30        r"""Store the output dimension of the backbone network."""
 31
 32        self.weighted_layer_names: list[str] = []
 33        r"""Maintain a list of the weighted layer names. A weighted layer has weights connecting to other weighted layers; such layers are the main part of a neural network. **It must be provided in subclasses.**
 34        
 35        The names follow the `nn.Module` internal naming mechanism. For example, if a layer is assigned to `self.conv1`, the name becomes `conv1`. If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc. If a hierarchical structure is used, for example, an `nn.Module` assigned to `self.block` which contains `self.conv1`, the name becomes `block/conv1`. Note that it would be `block.conv1` according to the `nn.Module` internal mechanism, but we use '/' instead of '.' to avoid the error caused by using '.' in `ModuleDict` keys.
 36        
 37        In the HAT architecture, this is also the ordered list of layer names whose units are masked by task embeddings during the forward pass; HAT attaches a task embedding to every weighted layer.
 38        """
 39
 40        self.task_id: int
 41        r"""Task ID counter indicating which task is being processed. It is updated automatically during the task loop."""
 42
 43    def setup_task_id(self, task_id: int) -> None:
 44        r"""Set up which task's dataset the CL experiment is working on. This must be done before the `forward()` method is called.
 45
 46        **Args:**
 47        - **task_id** (`int`): the target task ID.
 48        """
 49        self.task_id = task_id
 50
 51    def get_layer_by_name(self, layer_name: str) -> nn.Module:
 52        r"""Get the layer by its name.
 53
 54        **Args:**
 55        - **layer_name** (`str`): the name of the layer. Note that the name uses '/' in place of '.', e.g. `block/conv1` rather than `block.conv1`.
 56
 57        **Returns:**
 58        - **layer** (`nn.Module`): the layer.
 59        """
 60        for name, layer in self.named_modules():
 61            if name == layer_name.replace("/", "."):
 62                return layer
 63
 64    def preceding_layer_name(self, layer_name: str) -> str | None:
 65        r"""Get the name of the preceding layer of the given layer from the stored `self.weighted_layer_names`. If the given layer is the first layer, return `None`.
 66
 67        **Args:**
 68        - **layer_name** (`str`): the name of the layer.
 69
 70        **Returns:**
 71        - **preceding_layer_name** (`str` | `None`): the name of the preceding layer, or `None` if the given layer is the first layer.
 72
 73        **Raises:**
 74        - **ValueError**: if `layer_name` is not in the weighted layer order.
 75        """
 76
 77        if layer_name not in self.weighted_layer_names:
 78            raise ValueError(f"The layer name {layer_name} doesn't exist.")
 79
 80        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
 81        if weighted_layer_idx == 0:
 82            return None
 83        return self.weighted_layer_names[weighted_layer_idx - 1]
 84
 85    @override  # since `nn.Module` uses it
 86    def forward(
 87        self,
 88        input: Tensor,
 89        stage: str,
 90        task_id: int | None = None,
 91    ) -> tuple[Tensor, dict[str, Tensor]]:
 92        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**
 93
 94        **Args:**
 95        - **input** (`Tensor`): The input tensor from data.
 96        - **stage** (`str`): the stage of the forward pass, should be one of the following:
 97            1. 'train': training stage.
 98            2. 'validation': validation stage.
 99            3. 'test': testing stage.
100        - **task_id** (`int` | `None`): the task ID that the data come from. If stage is 'train' or 'validation', it is usually the current task `self.task_id`. If stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so it is only a placeholder for API consistency and is never used; best practice is to leave it at its default value.
101
102        **Returns:**
103        - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
104        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
105        """

The base class of continual learning backbone networks, inherited from nn.Module.

CLBackbone(output_dim: int | None)
21    def __init__(self, output_dim: int | None) -> None:
22        r"""Initialise the CL backbone network.
23
24        **Args:**
25        - **output_dim** (`int` | `None`): The output dimension which connects to the CL output heads. The `input_dim` of the output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block within a backbone network and has no output dimension; in that case, it can be `None`.
26        """
27        nn.Module.__init__(self)
28
29        self.output_dim: int = output_dim
30        r"""Store the output dimension of the backbone network."""
31
32        self.weighted_layer_names: list[str] = []
33        r"""Maintain a list of the weighted layer names. A weighted layer has weights connecting to other weighted layers; such layers are the main part of a neural network. **It must be provided in subclasses.**
34        
35        The names follow the `nn.Module` internal naming mechanism. For example, if a layer is assigned to `self.conv1`, the name becomes `conv1`. If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc. If a hierarchical structure is used, for example, an `nn.Module` assigned to `self.block` which contains `self.conv1`, the name becomes `block/conv1`. Note that it would be `block.conv1` according to the `nn.Module` internal mechanism, but we use '/' instead of '.' to avoid the error caused by using '.' in `ModuleDict` keys.
36        
37        In the HAT architecture, this is also the ordered list of layer names whose units are masked by task embeddings during the forward pass; HAT attaches a task embedding to every weighted layer.
38        """
39
40        self.task_id: int
41        r"""Task ID counter indicating which task is being processed. It is updated automatically during the task loop."""

Initialise the CL backbone network.

Args:

  • output_dim (int | None): The output dimension which connects to the CL output heads. The input_dim of the output heads is expected to be the same as this output_dim. In some cases, this class is used as a block within a backbone network and has no output dimension; in that case, it can be None.
output_dim: int

Store the output dimension of the backbone network.

weighted_layer_names: list[str]

Maintain a list of the weighted layer names. A weighted layer has weights connecting to other weighted layers; such layers are the main part of a neural network. It must be provided in subclasses.

The names follow the nn.Module internal naming mechanism. For example, if a layer is assigned to self.conv1, the name becomes conv1. If nn.Sequential is used, the name becomes the index of the layer in the sequence, such as 0, 1, etc. If a hierarchical structure is used, for example, an nn.Module assigned to self.block which contains self.conv1, the name becomes block/conv1. Note that it would be block.conv1 according to the nn.Module internal mechanism, but we use '/' instead of '.' to avoid the error caused by using '.' in ModuleDict keys.

In the HAT architecture, this is also the ordered list of layer names whose units are masked by task embeddings during the forward pass; HAT attaches a task embedding to every weighted layer.

task_id: int

Task ID counter indicating which task is being processed. It is updated automatically during the task loop.
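To make the naming convention above concrete, here is a small standalone sketch; the layer names and sizes are made up for illustration and are not taken from CLArena.

import torch.nn as nn

class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3)

class TinyBody(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.block = Block()

body = TinyBody()

# nn.Module names these layers 'fc1' and 'block.conv1' internally:
print([name for name, _ in body.named_modules() if name])  # ['fc1', 'block', 'block.conv1']

# Following the '/'-for-'.' convention described above, the weighted layer
# names a subclass would record are:
weighted_layer_names = ["fc1", "block/conv1"]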

def setup_task_id(self, task_id: int) -> None:
43    def setup_task_id(self, task_id: int) -> None:
44        r"""Set up which task's dataset the CL experiment is working on. This must be done before the `forward()` method is called.
45
46        **Args:**
47        - **task_id** (`int`): the target task ID.
48        """
49        self.task_id = task_id

Set up which task's dataset the CL experiment is working on. This must be done before the forward() method is called.

Args:

  • task_id (int): the target task ID.
def get_layer_by_name(self, layer_name: str) -> torch.nn.modules.module.Module:
51    def get_layer_by_name(self, layer_name: str) -> nn.Module:
52        r"""Get the layer by its name.
53
54        **Args:**
55        - **layer_name** (`str`): the name of the layer. Note that the name uses '/' in place of '.', e.g. `block/conv1` rather than `block.conv1`.
56
57        **Returns:**
58        - **layer** (`nn.Module`): the layer.
59        """
60        for name, layer in self.named_modules():
61            if name == layer_name.replace("/", "."):
62                return layer

Get the layer by its name.

Args:

  • layer_name (str): the name of the layer. Note that the name uses '/' in place of '.', e.g. block/conv1 rather than block.conv1.

Returns:

  • layer (nn.Module): the layer.
def preceding_layer_name(self, layer_name: str) -> str | None:
64    def preceding_layer_name(self, layer_name: str) -> str | None:
65        r"""Get the name of the preceding layer of the given layer from the stored `self.weighted_layer_names`. If the given layer is the first layer, return `None`.
66
67        **Args:**
68        - **layer_name** (`str`): the name of the layer.
69
70        **Returns:**
71        - **preceding_layer_name** (`str` | `None`): the name of the preceding layer, or `None` if the given layer is the first layer.
72
73        **Raises:**
74        - **ValueError**: if `layer_name` is not in the weighted layer order.
75        """
76
77        if layer_name not in self.weighted_layer_names:
78            raise ValueError(f"The layer name {layer_name} doesn't exist.")
79
80        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
81        if weighted_layer_idx == 0:
82            return None
83        return self.weighted_layer_names[weighted_layer_idx - 1]

Get the name of the preceding layer of the given layer from the stored self.weighted_layer_names. If the given layer is the first layer, return None.

Args:

  • layer_name (str): the name of the layer.

Returns:

  • preceding_layer_name (str | None): the name of the preceding layer, or None if the given layer is the first layer.

Raises:

  • ValueError: if layer_name is not in the weighted layer order.
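The lookup above is simply an index into the ordered weighted_layer_names list. A self-contained sketch of the same logic, with a hypothetical three-layer ordering:

weighted_layer_names = ["fc1", "fc2", "fc3"]

def preceding_layer_name(layer_name: str) -> str | None:
    # Same rule as CLBackbone.preceding_layer_name: return the previous entry
    # in the ordered list of weighted layer names, or None for the first layer.
    if layer_name not in weighted_layer_names:
        raise ValueError(f"The layer name {layer_name} doesn't exist.")
    idx = weighted_layer_names.index(layer_name)
    return None if idx == 0 else weighted_layer_names[idx - 1]

print(preceding_layer_name("fc2"))  # fc1
print(preceding_layer_name("fc1"))  # None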
@override
def forward( self, input: torch.Tensor, stage: str, task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
 85    @override  # since `nn.Module` uses it
 86    def forward(
 87        self,
 88        input: Tensor,
 89        stage: str,
 90        task_id: int | None = None,
 91    ) -> tuple[Tensor, dict[str, Tensor]]:
 92        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**
 93
 94        **Args:**
 95        - **input** (`Tensor`): The input tensor from data.
 96        - **stage** (`str`): the stage of the forward pass, should be one of the following:
 97            1. 'train': training stage.
 98            2. 'validation': validation stage.
 99            3. 'test': testing stage.
100        - **task_id** (`int` | `None`): the task ID that the data come from. If stage is 'train' or 'validation', it is usually the current task `self.task_id`. If stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so it is only a placeholder for API consistency and is never used; best practice is to leave it at its default value.
101
102        **Returns:**
103        - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
104        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
105        """

The forward pass for data from task task_id. In some backbones, the forward pass might be different for different tasks. It must be implemented by subclasses.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): the stage of the forward pass, should be one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • task_id (int | None): the task ID that the data come from. If stage is 'train' or 'validation', it is usually the current task self.task_id. If stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so it is only a placeholder for API consistency and is never used; best practice is to leave it at its default value.

Returns:

  • output_feature (Tensor): the output feature tensor to be passed into heads. This is the main target of backpropagation.
  • hidden_features (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
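For reference, a minimal subclass sketch that satisfies this contract. The layer names and sizes are invented for the example; the actual MLP and ResNet backbones in this submodule are more elaborate.

import torch
from torch import Tensor, nn
from clarena.backbones import CLBackbone

class TwoLayerBackbone(CLBackbone):
    """An illustrative CL backbone with two fully connected layers."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__(output_dim=output_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU()
        self.weighted_layer_names = ["fc1", "fc2"]  # required by the base class

    def forward(
        self, input: Tensor, stage: str, task_id: int | None = None
    ) -> tuple[Tensor, dict[str, Tensor]]:
        hidden_features: dict[str, Tensor] = {}
        feature = self.activation(self.fc1(input))
        hidden_features["fc1"] = feature
        feature = self.activation(self.fc2(feature))
        hidden_features["fc2"] = feature
        return feature, hidden_features

backbone = TwoLayerBackbone(input_dim=784, hidden_dim=256, output_dim=64)
output_feature, hidden_features = backbone(torch.randn(8, 784), stage="train")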
class HATMaskBackbone(clarena.backbones.CLBackbone):
108class HATMaskBackbone(CLBackbone):
109    r"""The backbone network for HAT-based algorithms with learnable hard attention masks.
110
111    HAT-based algorithms:
112
113    - [**HAT (Hard Attention to the Task, 2018)**](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select the task-specific parameters.
114    - [**Adaptive HAT (Adaptive Hard Attention to the Task, 2024)**](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) by introducing new adaptive soft gradient clipping based on parameter importance and network sparsity.
115    - **CBPHAT** is my work in progress, which combines the HAT (Hard Attention to the Task) algorithm with Continual Backpropagation (CBP) by leveraging contribution utility as the parameter importance, as in the AdaHAT (Adaptive Hard Attention to the Task) algorithm.
116    """
117
118    def __init__(self, output_dim: int | None, gate: str) -> None:
119        r"""Initialise the HAT mask backbone network with task embeddings and masks.
120
121        **Args:**
122        - **output_dim** (`int` | `None`): The output dimension which connects to the CL output heads. The `input_dim` of the output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block within a backbone network and has no output dimension; in that case, it can be `None`.
123        - **gate** (`str`): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
124            - `sigmoid`: the sigmoid function.
125        """
126        CLBackbone.__init__(self, output_dim=output_dim)
127
128        self.register_hat_mask_module_explicitly(
129            gate=gate
130        )  # we moved the registration of the modules to a separate method to solve a problem of multiple inheritance in terms of `nn.Module`
131
132        HATMaskBackbone.sanity_check(self)
133
134    def register_hat_mask_module_explicitly(self, gate: str) -> None:
135        r"""Register all `nn.Module`s explicitly in this method. For `HATMaskBackbone`, they are task embedding for the current task and the masks.
136
137        **Args:**
138        - **gate** (`str`): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
139            - `sigmoid`: the sigmoid function.
140        """
141        self.gate: str = gate
142        r"""Store the type of gate function."""
143        if gate == "sigmoid":
144            self.gate_fn: nn.Module = nn.Sigmoid()
145            r"""The gate function turning the real value task embeddings into attention masks."""
146
147        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
148        r"""Store the task embedding for the current task. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).
149        
150        We use `ModuleDict` rather than `dict` to make sure `LightningModule` can properly register these model parameters, for purposes such as automatically transferring them to the device and recording them in model summaries.
151        
152        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, because `nn.Embedding` is a type of `nn.Module` and can be accepted by `nn.ModuleDict` (`nn.Parameter` cannot).
153        
154        **This must be defined to cover each weighted layer (just as `self.weighted_layer_names` listed) in the backbone network.** Otherwise, the uncovered parts will keep updating for all tasks and become a source of catastrophic forgetting. """
155
156    def initialise_task_embedding(self, mode: str) -> None:
157        r"""Initialise the task embedding for the current task.
158
159        **Args:**
160        - **mode** (`str`): the initialisation mode for task embeddings, should be one of the following:
161            1. 'N01' (default): standard normal distribution $N(0, 1)$.
162            2. 'U-11': uniform distribution $U(-1, 1)$.
163            3. 'U01': uniform distribution $U(0, 1)$.
164            4. 'U-10': uniform distribution $U(-1, 0)$.
165            5. 'last': inherit task embedding from last task.
166        """
167        for te in self.task_embedding_t.values():
168            if mode == "N01":
169                nn.init.normal_(te.weight, 0, 1)
170            elif mode == "U-11":
171                nn.init.uniform_(te.weight, -1, 1)
172            elif mode == "U01":
173                nn.init.uniform_(te.weight, 0, 1)
174            elif mode == "U-10":
175                nn.init.uniform_(te.weight, -1, 0)
176            elif mode == "last":
177                pass
178
179    def sanity_check(self) -> None:
180        r"""Check the sanity of the arguments.
181
182        **Raises:**
183        - **ValueError**: when the `gate` is not one of the valid options.
184        """
185
186        if self.gate not in ["sigmoid"]:
187            raise ValueError("The gate should be one of 'sigmoid'.")
188
189    def get_mask(
190        self,
191        stage: str,
192        s_max: float | None = None,
193        batch_idx: int | None = None,
194        num_batches: int | None = None,
195        test_mask: dict[str, Tensor] | None = None,
196    ) -> dict[str, Tensor]:
197        r"""Get the hard attention mask used in `forward()` method for different stages.
198
199        **Args:**
200        - **stage** (`str`): the stage when applying the conversion, should be one of the following:
201            1. 'train': training stage. If stage is 'train', get the mask from task embedding of current task through the gate function, which is scaled by an annealed scalar. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
202            2. 'validation': validation stage. If stage is 'validation', get the mask from the task embedding of the current task through the gate function, scaled by `s_max`. (Note that at this stage, the binary mask hasn't been stored yet, as training is not over.)
203            3. 'test': testing stage. If stage is 'test', use the provided binary `test_mask` for the tested task. (It was obtained from the gate function scaled by `s_max`; the large scaling makes the mask nearly binary.)
204        - **s_max** (`float`): the maximum scaling factor in the gate function. Doesn't apply to testing stage. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
205        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
206        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
207        - **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for test. Applies only to testing stage. For other stages, it is default `None`.
208
209        **Returns:**
210        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
211
212        **Raises:**
213        - **ValueError**: if `s_max`, `batch_idx` or `num_batches` is not provided in the 'train' stage; if `s_max` is not provided in the 'validation' stage; if `test_mask` is not provided in the 'test' stage.
214        """
215
216        # sanity check
217        if stage == "train" and (
218            s_max is None or batch_idx is None or num_batches is None
219        ):
220            raise ValueError(
221                "The `s_max`, `batch_idx` and `num_batches` should be provided at training stage, instead of the default value `None`."
222            )
223        if stage == "validation" and (s_max is None):
224            raise ValueError(
225                "The `s_max` should be provided at validation stage, instead of the default value `None`."
226            )
227        if stage == "test" and (test_mask is None):
228            raise ValueError(
229                "The `test_mask` should be provided at testing stage, instead of the default value `None`."
230            )
231
232        mask = {}
233        if stage == "train":
234            for layer_name in self.weighted_layer_names:
235                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
236                    num_batches - 1
237                )  # see equation (3) in chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
238                mask[layer_name] = self.gate_fn(
239                    self.task_embedding_t[layer_name].weight * anneal_scalar
240                ).squeeze()
241        elif stage == "validation":
242            for layer_name in self.weighted_layer_names:
243                mask[layer_name] = self.gate_fn(
244                    self.task_embedding_t[layer_name].weight * s_max
245                ).squeeze()
246        elif stage == "test":
247            mask = test_mask
248
249        return mask
250
251    def get_cumulative_mask(self) -> dict[str, Tensor]:
252        r"""Get the cumulative mask till current task.
253
254        **Returns:**
255        - **cumulative_mask** (`dict[str, Tensor]`): the cumulative mask. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
256        """
257        return self.cumulative_mask_for_previous_tasks
258
259    def get_summative_mask(self) -> dict[str, Tensor]:
260        r"""Get the summative mask till current task.
261
262        **Returns:**
263        - **summative_mask** (`dict[str, Tensor]`): the summative mask tensor. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
264        """
265        return self.summative_mask_for_previous_tasks
266
267    def get_layer_measure_parameter_wise(
268        self,
269        unit_wise_measure: dict[str, Tensor],
270        layer_name: str,
271        aggregation: str,
272    ) -> Tensor:
273        r"""Get the parameter-wise measure on the parameters right before the given layer.
274
275        It is calculated from the given unit-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and its preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.
276
277        Note that if the given layer is the first layer, with no preceding layer, the parameter-wise measure is broadcast directly from the unit-wise measure of the given layer.
278
279        This method is used in the calculation of parameter-wise measure in various HAT-based algorithms:
280
281        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the unit-wise cumulative mask of previous tasks `self.cumulative_mask_for_previous_tasks`, which is $\min \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in equation (2) in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
282        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the unit-wise summative mask of previous tasks `self.summative_mask_for_previous_tasks`, which is $\min \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
283        - **CBPHAT**: the parameter-wise measure is the parameter importance for previous tasks from the unit-wise importance of previous tasks `self.unit_importance_for_previous_tasks` based on contribution utility, which is $\min \left(I_{l,i}^{(t-1)}, I_{l-1,j}^{(t-1)}\right)$ in the adjustment rate formula in the paper draft.
284
285        **Args:**
286        - **unit_wise_measure** (`dict[str, Tensor]`): the unit-wise measure. Key is layer name, value is the unit-wise measure tensor. The measure tensor has size (number of units).
287        - **layer_name** (`str`): the name of given layer.
288        - **aggregation** (`str`): the aggregation method turning two feature-wise measures into weight-wise matrix, should be one of the following:
289            - 'min': takes minimum of the two connected unit measures.
290            - 'max': takes maximum of the two connected unit measures.
291
292        **Returns:**
293        - **weight_measure** (`Tensor`): the weight measure matrix, same size as the corresponding weights.
294        - **bias_measure** (`Tensor`): the bias measure vector, same size as the corresponding bias.
295
296
297        """
298
299        # initialise the aggregation function
300        if aggregation == "min":
301            aggregation_func = torch.min
302        elif aggregation == "max":
303            aggregation_func = torch.max
304        else:
305            raise ValueError(f"The aggregation method {aggregation} is not supported.")
306
307        # get the preceding layer name
308        preceding_layer_name = self.preceding_layer_name(layer_name)
309
310        # get weight size for expanding the measures
311        layer = self.get_layer_by_name(layer_name)
312        weight_size = layer.weight.size()
313
314        # construct the weight-wise measure
315        layer_measure = unit_wise_measure[layer_name]
316        layer_measure_broadcast_size = (-1, 1) + tuple(
317            1 for _ in range(len(weight_size) - 2)
318        )  # since the size of mask tensor is (number of units), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
319
320        layer_measure_broadcasted = layer_measure.view(
321            *layer_measure_broadcast_size
322        ).expand(
323            weight_size,
324        )  # expand the given layer mask to the weight size and broadcast
325
326        if (
327            preceding_layer_name
328        ):  # if the layer is not the first layer, where the preceding layer exists
329
330            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
331                1 for _ in range(len(weight_size) - 2)
332            )  # since the size of mask tensor is (number of units), we extend it to (1, number of units) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
333            preceding_layer_measure = unit_wise_measure[preceding_layer_name]
334            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
335                *preceding_layer_measure_broadcast_size
336            ).expand(
337                weight_size
338            )  # expand the preceding layer mask to the weight size and broadcast
339            weight_measure = aggregation_func(
340                layer_measure_broadcasted, preceding_layer_measure_broadcasted
341            )  # aggregate the two broadcast measures element-wise (min or max)
342        else:  # if the layer is the first layer
343            weight_measure = layer_measure_broadcasted
344
345        # construct the bias-wise measure
346        bias_measure = layer_measure
347
348        return weight_measure, bias_measure
349
350    @override
351    def forward(
352        self,
353        input: Tensor,
354        stage: str,
355        s_max: float | None = None,
356        batch_idx: int | None = None,
357        num_batches: int | None = None,
358        test_mask: dict[str, Tensor] | None = None,
359    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
360        r"""The forward pass for data from task `task_id`. The task-specific mask for `task_id` is applied to the units in each layer.
361
362        **Args:**
363        - **input** (`Tensor`): The input tensor from data.
364        - **stage** (`str`): the stage of the forward pass, should be one of the following:
365            1. 'train': training stage.
366            2. 'validation': validation stage.
367            3. 'test': testing stage.
368        - **s_max** (`float`): the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
369        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
370        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
371        - **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for test. Applies only to testing stage. For other stages, it is default `None`.
372
373        **Returns:**
374        - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
375        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
376        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm itself does not need it, it is still provided for API consistency with other HAT-based algorithms that inherit this `forward()` method.
377
378        """
379        # this should be copied to all subclasses. Make sure it is called to get the mask for the current task from the task embedding in this stage
380        mask = self.get_mask(
381            stage,
382            s_max=s_max,
383            batch_idx=batch_idx,
384            num_batches=num_batches,
385            test_mask=test_mask,
386        )

The backbone network for HAT-based algorithms with learnable hard attention masks.

HAT-based algorithms:

  • HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select the task-specific parameters.
  • Adaptive HAT (Adaptive Hard Attention to the Task, 2024) is an architecture-based continual learning approach that improves HAT (Hard Attention to the Task, 2018) by introducing new adaptive soft gradient clipping based on parameter importance and network sparsity.
  • CBPHAT is my work in progress, which combines the HAT (Hard Attention to the Task) algorithm with Continual Backpropagation (CBP) by leveraging contribution utility as the parameter importance, as in the AdaHAT (Adaptive Hard Attention to the Task) algorithm.
HATMaskBackbone(output_dim: int | None, gate: str)
118    def __init__(self, output_dim: int | None, gate: str) -> None:
119        r"""Initialise the HAT mask backbone network with task embeddings and masks.
120
121        **Args:**
122        - **output_dim** (`int` | `None`): The output dimension which connects to the CL output heads. The `input_dim` of the output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block within a backbone network and has no output dimension; in that case, it can be `None`.
123        - **gate** (`str`): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
124            - `sigmoid`: the sigmoid function.
125        """
126        CLBackbone.__init__(self, output_dim=output_dim)
127
128        self.register_hat_mask_module_explicitly(
129            gate=gate
130        )  # we moved the registration of the modules to a separate method to solve a problem of multiple inheritance in terms of `nn.Module`
131
132        HATMaskBackbone.sanity_check(self)

Initialise the HAT mask backbone network with task embeddings and masks.

Args:

  • output_dim (int | None): The output dimension which connects to the CL output heads. The input_dim of the output heads is expected to be the same as this output_dim. In some cases, this class is used as a block within a backbone network and has no output dimension; in that case, it can be None.
  • gate (str): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
    • sigmoid: the sigmoid function.
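A hedged sketch of how a concrete subclass typically calls this constructor and registers its per-layer task embeddings. The layer names and sizes are invented for the example; see HATMaskMLP and the HATMaskResNet classes for the real implementations.

from torch import nn
from clarena.backbones import HATMaskBackbone

class TinyHATMaskBackbone(HATMaskBackbone):
    """An illustrative HAT mask backbone with two fully connected layers."""

    def __init__(self, output_dim: int, gate: str = "sigmoid") -> None:
        super().__init__(output_dim=output_dim, gate=gate)
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, output_dim)
        self.weighted_layer_names = ["fc1", "fc2"]
        # one task embedding per weighted layer, each of size (1, number of units)
        self.task_embedding_t["fc1"] = nn.Embedding(1, 256)
        self.task_embedding_t["fc2"] = nn.Embedding(1, output_dim)
        # a full subclass would also override forward() to apply the masks
        # returned by get_mask() (see the forward() method further below)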
def register_hat_mask_module_explicitly(self, gate: str) -> None:
134    def register_hat_mask_module_explicitly(self, gate: str) -> None:
135        r"""Register all `nn.Module`s explicitly in this method. For `HATMaskBackbone`, they are task embedding for the current task and the masks.
136
137        **Args:**
138        - **gate** (`str`): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
139            - `sigmoid`: the sigmoid function.
140        """
141        self.gate: str = gate
142        r"""Store the type of gate function."""
143        if gate == "sigmoid":
144            self.gate_fn: nn.Module = nn.Sigmoid()
145            r"""The gate function turning the real value task embeddings into attention masks."""
146
147        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
148        r"""Store the task embedding for the current task. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).
149        
150        We use `ModuleDict` rather than `dict` to make sure `LightningModule` can properly register these model parameters, for purposes such as automatically transferring them to the device and recording them in model summaries.
151        
152        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, because `nn.Embedding` is a type of `nn.Module` and can be accepted by `nn.ModuleDict` (`nn.Parameter` cannot).
153        
154        **This must be defined to cover each weighted layer (just as `self.weighted_layer_names` listed) in the backbone network.** Otherwise, the uncovered parts will keep updating for all tasks and become a source of catastrophic forgetting. """

Register all nn.Modules explicitly in this method. For HATMaskBackbone, they are task embedding for the current task and the masks.

Args:

  • gate (str): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
    • sigmoid: the sigmoid function.
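As a quick illustration of what the sigmoid gate does (a standalone sketch; the embedding values and the scaling factor 400 are just example numbers): applied to the raw task embedding it produces soft values in (0, 1), and applied to the embedding scaled by a large factor it becomes nearly binary.

import torch
from torch import nn

gate_fn = nn.Sigmoid()
task_embedding = torch.tensor([[-2.0, 0.1, 3.0]])  # shape (1, number of units)

print(gate_fn(task_embedding))          # soft mask, values strictly between 0 and 1
print(gate_fn(task_embedding * 400.0))  # scaled by a large s_max: values are ~0 or ~1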
def initialise_task_embedding(self, mode: str) -> None:
156    def initialise_task_embedding(self, mode: str) -> None:
157        r"""Initialise the task embedding for the current task.
158
159        **Args:**
160        - **mode** (`str`): the initialisation mode for task embeddings, should be one of the following:
161            1. 'N01' (default): standard normal distribution $N(0, 1)$.
162            2. 'U-11': uniform distribution $U(-1, 1)$.
163            3. 'U01': uniform distribution $U(0, 1)$.
164            4. 'U-10': uniform distribution $U(-1, 0)$.
165            5. 'last': inherit task embedding from last task.
166        """
167        for te in self.task_embedding_t.values():
168            if mode == "N01":
169                nn.init.normal_(te.weight, 0, 1)
170            elif mode == "U-11":
171                nn.init.uniform_(te.weight, -1, 1)
172            elif mode == "U01":
173                nn.init.uniform_(te.weight, 0, 1)
174            elif mode == "U-10":
175                nn.init.uniform_(te.weight, -1, 0)
176            elif mode == "last":
177                pass

Initialise the task embedding for the current task.

Args:

  • mode (str): the initialisation mode for task embeddings, should be one of the following:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit task embedding from last task.
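The modes above map directly onto standard torch.nn.init calls. A standalone sketch for a single hypothetical layer with 256 units:

from torch import nn

te = nn.Embedding(1, 256)  # one task embedding of size (1, number of units)

nn.init.normal_(te.weight, 0, 1)    # mode 'N01': N(0, 1)
nn.init.uniform_(te.weight, -1, 1)  # mode 'U-11': U(-1, 1)
nn.init.uniform_(te.weight, 0, 1)   # mode 'U01': U(0, 1)
nn.init.uniform_(te.weight, -1, 0)  # mode 'U-10': U(-1, 0)
# mode 'last': do nothing, inheriting the embedding left over from the last task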
def sanity_check(self) -> None:
179    def sanity_check(self) -> None:
180        r"""Check the sanity of the arguments.
181
182        **Raises:**
183        - **ValueError**: when the `gate` is not one of the valid options.
184        """
185
186        if self.gate not in ["sigmoid"]:
187            raise ValueError("The gate should be one of 'sigmoid'.")

Check the sanity of the arguments.

Raises:

  • ValueError: when the gate is not one of the valid options.
def get_mask( self, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_mask: dict[str, torch.Tensor] | None = None) -> dict[str, torch.Tensor]:
189    def get_mask(
190        self,
191        stage: str,
192        s_max: float | None = None,
193        batch_idx: int | None = None,
194        num_batches: int | None = None,
195        test_mask: dict[str, Tensor] | None = None,
196    ) -> dict[str, Tensor]:
197        r"""Get the hard attention mask used in `forward()` method for different stages.
198
199        **Args:**
200        - **stage** (`str`): the stage when applying the conversion, should be one of the following:
201            1. 'train': training stage. If stage is 'train', get the mask from task embedding of current task through the gate function, which is scaled by an annealed scalar. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
202            2. 'validation': validation stage. If stage is 'validation', get the mask from the task embedding of the current task through the gate function, scaled by `s_max`. (Note that at this stage, the binary mask hasn't been stored yet, as training is not over.)
203            3. 'test': testing stage. If stage is 'test', use the provided binary `test_mask` for the tested task. (It was obtained from the gate function scaled by `s_max`; the large scaling makes the mask nearly binary.)
204        - **s_max** (`float`): the maximum scaling factor in the gate function. Doesn't apply to testing stage. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
205        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
206        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
207        - **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for test. Applies only to testing stage. For other stages, it is default `None`.
208
209        **Returns:**
210        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
211
212        **Raises:**
213        - **ValueError**: if `s_max`, `batch_idx` or `num_batches` is not provided in the 'train' stage; if `s_max` is not provided in the 'validation' stage; if `test_mask` is not provided in the 'test' stage.
214        """
215
216        # sanity check
217        if stage == "train" and (
218            s_max is None or batch_idx is None or num_batches is None
219        ):
220            raise ValueError(
221                "The `s_max`, `batch_idx` and `num_batches` should be provided at training stage, instead of the default value `None`."
222            )
223        if stage == "validation" and (s_max is None):
224            raise ValueError(
225                "The `s_max` should be provided at validation stage, instead of the default value `None`."
226            )
227        if stage == "test" and (test_mask is None):
228            raise ValueError(
229                "The `test_mask` should be provided at testing stage, instead of the default value `None`."
230            )
231
232        mask = {}
233        if stage == "train":
234            for layer_name in self.weighted_layer_names:
235                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
236                    num_batches - 1
237                )  # see equation (3) in chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
238                mask[layer_name] = self.gate_fn(
239                    self.task_embedding_t[layer_name].weight * anneal_scalar
240                ).squeeze()
241        elif stage == "validation":
242            for layer_name in self.weighted_layer_names:
243                mask[layer_name] = self.gate_fn(
244                    self.task_embedding_t[layer_name].weight * s_max
245                ).squeeze()
246        elif stage == "test":
247            mask = test_mask
248
249        return mask

Get the hard attention mask used in forward() method for different stages.

Args:

  • stage (str): the stage when applying the conversion, should be one of the following:
    1. 'train': training stage. If stage is 'train', get the mask from task embedding of current task through the gate function, which is scaled by an annealed scalar. See chapter 2.4 "Hard Attention Training" in HAT paper.
    2. 'validation': validation stage. If stage is 'validation', get the mask from the task embedding of the current task through the gate function, scaled by s_max. (Note that at this stage, the binary mask hasn't been stored yet, as training is not over.)
    3. 'test': testing stage. If stage is 'test', use the provided binary test_mask for the tested task. (It was obtained from the gate function scaled by s_max; the large scaling makes the mask nearly binary.)
  • s_max (float): the maximum scaling factor in the gate function. Doesn't apply to testing stage. See chapter 2.4 "Hard Attention Training" in HAT paper.
  • batch_idx (int | None): the current batch index. Applies only to training stage. For other stages, it is default None.
  • num_batches (int | None): the total number of batches. Applies only to training stage. For other stages, it is default None.
  • test_mask (dict[str, Tensor] | None): the binary mask used for test. Applies only to testing stage. For other stages, it is default None.

Returns:

  • mask (dict[str, Tensor]): the hard attention (whose values are 0 or 1) mask. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has size (number of units).

Raises:

  • ValueError: if s_max, batch_idx or num_batches is not provided in the 'train' stage; if s_max is not provided in the 'validation' stage; if test_mask is not provided in the 'test' stage.
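To see the training-stage annealing in isolation, here is a standalone sketch of equation (3) from the HAT paper as used above, assuming 1-based batch indices as the (batch_idx - 1) term suggests; s_max = 400 and 5 batches are example values.

s_max = 400.0
num_batches = 5

for batch_idx in range(1, num_batches + 1):
    # the scaling grows linearly from 1 / s_max (first batch) to s_max (last batch)
    anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)
    print(batch_idx, anneal_scalar)
# the scalar is approximately 0.0025, 100.0, 200.0, 300.0, 400.0 across the five batches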
def get_cumulative_mask(self) -> dict[str, torch.Tensor]:
251    def get_cumulative_mask(self) -> dict[str, Tensor]:
252        r"""Get the cumulative mask till current task.
253
254        **Returns:**
255        - **cumulative_mask** (`dict[str, Tensor]`): the cumulative mask. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
256        """
257        return self.cumulative_mask_for_previous_tasks

Get the cumulative mask till current task.

Returns:

  • cumulative_mask (dict[str, Tensor]): the cumulative mask. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has size (number of units).
def get_summative_mask(self) -> dict[str, torch.Tensor]:
259    def get_summative_mask(self) -> dict[str, Tensor]:
260        r"""Get the summative mask till current task.
261
262        **Returns:**
263        - **summative_mask** (`dict[str, Tensor]`): the summative mask tensor. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
264        """
265        return self.summative_mask_for_previous_tasks

Get the summative mask till current task.

Returns:

  • summative_mask (dict[str, Tensor]): the summative mask tensor. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has size (number of units).
def get_layer_measure_parameter_wise( self, unit_wise_measure: dict[str, torch.Tensor], layer_name: str, aggregation: str) -> torch.Tensor:
267    def get_layer_measure_parameter_wise(
268        self,
269        unit_wise_measure: dict[str, Tensor],
270        layer_name: str,
271        aggregation: str,
272    ) -> Tensor:
273        r"""Get the parameter-wise measure on the parameters right before the given layer.
274
275        It is calculated from the given unit-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and its preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.
276
277        Note that if the given layer is the first layer, with no preceding layer, the parameter-wise measure is broadcast directly from the unit-wise measure of the given layer.
278
279        This method is used in the calculation of parameter-wise measure in various HAT-based algorithms:
280
281        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the unit-wise cumulative mask of previous tasks `self.cumulative_mask_for_previous_tasks`, which is $\min \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in equation (2) in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
282        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the unit-wise summative mask of previous tasks `self.summative_mask_for_previous_tasks`, which is $\min \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
283        - **CBPHAT**: the parameter-wise measure is the parameter importance for previous tasks from the unit-wise importance of previous tasks `self.unit_importance_for_previous_tasks` based on contribution utility, which is $\min \left(I_{l,i}^{(t-1)}, I_{l-1,j}^{(t-1)}\right)$ in the adjustment rate formula in the paper draft.
284
285        **Args:**
286        - **unit_wise_measure** (`dict[str, Tensor]`): the unit-wise measure. Key is layer name, value is the unit-wise measure tensor. The measure tensor has size (number of units).
287        - **layer_name** (`str`): the name of given layer.
288        - **aggregation** (`str`): the aggregation method turning two feature-wise measures into weight-wise matrix, should be one of the following:
289            - 'min': takes minimum of the two connected unit measures.
290            - 'max': takes maximum of the two connected unit measures.
291
292        **Returns:**
293        - **weight_measure** (`Tensor`): the weight measure matrix, same size as the corresponding weights.
294        - **bias_measure** (`Tensor`): the bias measure vector, same size as the corresponding bias.
295
296
297        """
298
299        # initialise the aggregation function
300        if aggregation == "min":
301            aggregation_func = torch.min
302        elif aggregation == "max":
303            aggregation_func = torch.max
304        else:
305            raise ValueError(f"The aggregation method {aggregation} is not supported.")
306
307        # get the preceding layer name
308        preceding_layer_name = self.preceding_layer_name(layer_name)
309
310        # get weight size for expanding the measures
311        layer = self.get_layer_by_name(layer_name)
312        weight_size = layer.weight.size()
313
314        # construct the weight-wise measure
315        layer_measure = unit_wise_measure[layer_name]
316        layer_measure_broadcast_size = (-1, 1) + tuple(
317            1 for _ in range(len(weight_size) - 2)
318        )  # since the size of mask tensor is (number of units), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
319
320        layer_measure_broadcasted = layer_measure.view(
321            *layer_measure_broadcast_size
322        ).expand(
323            weight_size,
324        )  # expand the given layer mask to the weight size and broadcast
325
326        if (
327            preceding_layer_name
328        ):  # if the layer is not the first layer, where the preceding layer exists
329
330            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
331                1 for _ in range(len(weight_size) - 2)
332            )  # since the size of mask tensor is (number of units), we extend it to (1, number of units) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
333            preceding_layer_measure = unit_wise_measure[preceding_layer_name]
334            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
335                *preceding_layer_measure_broadcast_size
336            ).expand(
337                weight_size
338            )  # expand the preceding layer mask to the weight size and broadcast
339            weight_measure = aggregation_func(
340                layer_measure_broadcasted, preceding_layer_measure_broadcasted
341            )  # aggregate the two broadcast measures element-wise (min or max)
342        else:  # if the layer is the first layer
343            weight_measure = layer_measure_broadcasted
344
345        # construct the bias-wise measure
346        bias_measure = layer_measure
347
348        return weight_measure, bias_measure

Get the parameter-wise measure on the parameters right before the given layer.

It is calculated from the given unit-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and its preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.

Note that if the given layer is the first layer, with no preceding layer, the parameter-wise measure is broadcast directly from the unit-wise measure of the given layer.

This method is used in the calculation of parameter-wise measure in various HAT-based algorithms:

  • HAT: the parameter-wise measure is the binary mask for previous tasks from the unit-wise cumulative mask of previous tasks self.cumulative_mask_for_previous_tasks, which is $\min \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in equation (2) in the HAT paper.
  • AdaHAT: the parameter-wise measure is the parameter importance for previous tasks from the unit-wise summative mask of previous tasks self.summative_mask_for_previous_tasks, which is $\min \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in equation (9) in the AdaHAT paper.
  • CBPHAT: the parameter-wise measure is the parameter importance for previous tasks from the unit-wise importance of previous tasks self.unit_importance_for_previous_tasks based on contribution utility, which is $\min \left(I_{l,i}^{(t-1)}, I_{l-1,j}^{(t-1)}\right)$ in the adjustment rate formula in the paper draft.

Args:

  • unit_wise_measure (dict[str, Tensor]): the unit-wise measure. Key is layer name, value is the unit-wise measure tensor. The measure tensor has size (number of units).
  • layer_name (str): the name of given layer.
  • aggregation (str): the aggregation method turning two feature-wise measures into weight-wise matrix, should be one of the following:
    • 'min': takes minimum of the two connected unit measures.
    • 'max': takes maximum of the two connected unit measures.

Returns:

  • weight_measure (Tensor): the weight measure matrix, same size as the corresponding weights.
  • bias_measure (Tensor): the bias measure vector, same size as the corresponding bias.
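A standalone numeric sketch of the 'min' aggregation for a fully connected layer, where the weight has shape (units in this layer, units in the preceding layer); the measure values are made up for illustration.

import torch

layer_measure = torch.tensor([1.0, 0.0, 1.0])      # 3 units in the given layer
preceding_measure = torch.tensor([0.0, 1.0])       # 2 units in the preceding layer
weight_size = (3, 2)                                # shape of the weights in between

weight_measure = torch.min(
    layer_measure.view(-1, 1).expand(weight_size),      # broadcast along columns
    preceding_measure.view(1, -1).expand(weight_size),  # broadcast along rows
)
bias_measure = layer_measure

print(weight_measure)
# tensor([[0., 1.],
#         [0., 0.],
#         [0., 1.]])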
@override
def forward( self, input: torch.Tensor, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_mask: dict[str, torch.Tensor] | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor]]:
350    @override
351    def forward(
352        self,
353        input: Tensor,
354        stage: str,
355        s_max: float | None = None,
356        batch_idx: int | None = None,
357        num_batches: int | None = None,
358        test_mask: dict[str, Tensor] | None = None,
359    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
360        r"""The forward pass for data from task `task_id`. The task-specific mask for `task_id` is applied to the units in each layer.
361
362        **Args:**
363        - **input** (`Tensor`): The input tensor from data.
364        - **stage** (`str`): the stage of the forward pass, should be one of the following:
365            1. 'train': training stage.
366            2. 'validation': validation stage.
367            3. 'test': testing stage.
368        - **s_max** (`float`): the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
369        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
370        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
371        - **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for test. Applies only to testing stage. For other stages, it is default `None`.
372
373        **Returns:**
374        - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
375        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
376        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm itself does not need it, it is still provided for API consistency with other HAT-based algorithms that inherit this `forward()` method.
377
378        """
379        # this should be copied to all subclasses. Make sure it is called to get the mask for the current task from the task embedding in this stage
380        mask = self.get_mask(
381            stage,
382            s_max=s_max,
383            batch_idx=batch_idx,
384            num_batches=num_batches,
385            test_mask=test_mask,
386        )

The forward pass for data from task task_id. The task-specific mask for task_id is applied to the units in each layer.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): the stage of the forward pass, should be one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • s_max (float): the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in HAT paper.
  • batch_idx (int | None): the current batch index. Applies only to training stage. For other stages, it is default None.
  • num_batches (int | None): the total number of batches. Applies only to training stage. For other stages, it is default None.
  • test_mask (dict[str, Tensor] | None): the binary mask used for test. Applies only to testing stage. For other stages, it is default None.

Returns:

  • output_feature (Tensor): the output feature tensor to be passed into heads. This is the main target of backpropagation.
  • mask (dict[str, Tensor]): the mask for the current task. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has size (number of units).
  • hidden_features (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm itself does not need it, it is still provided for API consistency with other HAT-based algorithms that inherit this forward() method.
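Finally, a sketch (not the CLArena implementation) of how a HATMaskBackbone subclass typically continues this forward() after the get_mask() call above: each weighted layer's activation is multiplied element-wise by its mask, and both the masked features and the mask are returned. The layer size and the all-ones test mask are illustrative.

import torch
from torch import Tensor, nn

fc1 = nn.Linear(784, 256)
activation = nn.ReLU()

def masked_layer_forward(input: Tensor, mask_fc1: Tensor) -> Tensor:
    feature = activation(fc1(input))
    return feature * mask_fc1  # mask_fc1 has size (number of units) = 256

x = torch.randn(8, 784)
mask_fc1 = torch.ones(256)  # e.g. a test mask that keeps every unit of fc1
print(masked_layer_forward(x, mask_fc1).shape)  # torch.Size([8, 256])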