clarena.backbones
Backbone Networks for Continual Learning
This submodule provides the backbone neural network architectures for continual learning.
Please note that this is API documentation. Please refer to the main documentation pages for more information about the backbone networks and how to configure and implement them:

- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/backbone-network)
- [**Implement Your CL Backbone Class**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-CL-modules/backbone-network)
The backbones are implemented as subclasses of `CLBackbone`, which is the base class for all continual learning backbones in CLArena.
- `CLBackbone`: the base class for continual learning backbones.
- `HATMaskBackbone`: the base class for backbones used in the [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a); a child class of `CLBackbone`.
r"""

# Backbone Networks for Continual Learning

This submodule provides the **backbone neural network architectures for continual learning**.

Please note that this is an API documentation. Please refer to the main documentation pages for more information about the backbone networks and how to
configure and implement them:

- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/backbone-network)
- [**Implement Your CL Backbone Class**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/implement-your-CL-modules/backbone-network)

The backbones are implemented as subclasses of `CLBackbone` classes, which are the base class for all continual learning backbones in CLArena.

- `CLBackbone`: The base class for continual learning backbones.
- `HATMaskBackbone`: The base class for backbones used in [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a). A child class of `CLBackbone`.

"""

from .base import CLBackbone, HATMaskBackbone
from .mlp import MLP, HATMaskMLP
from .resnet import (
    HATMaskResNet18,
    HATMaskResNet34,
    HATMaskResNet50,
    HATMaskResNet101,
    HATMaskResNet152,
    ResNet18,
    ResNet34,
    ResNet50,
    ResNet101,
    ResNet152,
)

__all__ = ["CLBackbone", "HATMaskBackbone", "mlp", "resnet"]
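For orientation, the exported classes can be imported directly from `clarena.backbones`. The snippet below is only a trivial check of the class hierarchy described above:

```python
import torch.nn as nn

from clarena.backbones import CLBackbone, HATMaskBackbone

# the HAT-specific base class specialises the generic CL backbone base class,
# and both are ordinary PyTorch modules
assert issubclass(HATMaskBackbone, CLBackbone)
assert issubclass(CLBackbone, nn.Module)
```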
class CLBackbone(nn.Module):
    r"""The base class of continual learning backbone networks, inherited from `nn.Module`."""
The base class of continual learning backbone networks, inherited from `nn.Module`.
    def __init__(self, output_dim: int | None) -> None:
        r"""Initialise the CL backbone network.

        **Args:**
        - **output_dim** (`int` | `None`): The output dimension which connects to CL output heads. The `input_dim` of output heads are expected to be the same as this `output_dim`. In some cases, this class is used for a block in the backbone network, which doesn't have the output dimension. In this case, it can be `None`.
        """
        nn.Module.__init__(self)

        self.output_dim: int = output_dim
        r"""Store the output dimension of the backbone network."""

        self.weighted_layer_names: list[str] = []
        r"""Maintain a list of the weighted layer names. Weighted layer has weights connecting to other weighted layer. They are the main part of neural networks. **It must be provided in subclasses.**

        The names are following the `nn.Module` internal naming mechanism. For example, if the a layer is assigned to `self.conv1`, the name becomes `conv1`. If the `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc. If hierarchical structure is used, for example, a `nn.Module` is assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it should be `block.conv1` according to `nn.Module` internal mechanism, but we use '/' instead of '.' to avoid the error of using '.' in the key of `ModuleDict`.

        In HAT architecture, it's also the layer names with task embedding masking in the order of forward pass. HAT gives task embedding to every possible weighted layer.
        """

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""
Initialise the CL backbone network.

**Args:**

- **output_dim** (`int` | `None`): the output dimension which connects to CL output heads. The `input_dim` of the output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that has no output dimension; in that case, it can be `None`.
**weighted_layer_names** (`list[str]`): maintains a list of the weighted layer names. A weighted layer has weights connecting it to other weighted layers; such layers form the main body of the network. **It must be provided in subclasses.**

The names follow the `nn.Module` internal naming mechanism. For example, if a layer is assigned to `self.conv1`, its name becomes `conv1`. If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc. If a hierarchical structure is used, for example an `nn.Module` assigned to `self.block` that contains `self.conv1`, the name becomes `block/conv1`. Note that the `nn.Module` internal mechanism would name it `block.conv1`; we use '/' instead of '.' because '.' is not allowed in `ModuleDict` keys.

In the HAT architecture, this list also gives the layers whose units are masked by task embeddings, in forward-pass order. HAT attaches a task embedding to every possible weighted layer.

**task_id** (`int`): task ID counter indicating which task is being processed. Updated automatically during the task loop.
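To make the naming convention concrete, here is a small hypothetical subclass (not part of CLArena) whose weighted layers include one nested inside a sub-module; the nested layer is listed with '/' separators, in forward-pass order:

```python
from torch import nn

from clarena.backbones import CLBackbone


class Block(nn.Module):
    """A hypothetical sub-module holding a weighted layer."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)


class ToyConvBackbone(CLBackbone):
    """Hypothetical backbone used only to illustrate weighted layer naming."""

    def __init__(self) -> None:
        CLBackbone.__init__(self, output_dim=10)
        self.block = Block()        # registered internally as "block.conv1"
        self.fc = nn.Linear(8, 10)  # registered as "fc"
        self.weighted_layer_names = ["block/conv1", "fc"]  # '/' instead of '.', forward-pass order
```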
    def setup_task_id(self, task_id: int) -> None:
        r"""Set up which task's dataset the CL experiment is on. This must be done before `forward()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        """
        self.task_id = task_id
Set up which task's dataset the CL experiment is on. This must be done before the `forward()` method is called.

**Args:**

- **task_id** (`int`): the target task ID.
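A minimal sketch of where this call sits in a task loop; `NoOpBackbone` is a hypothetical placeholder subclass used only so the snippet runs:

```python
from clarena.backbones import CLBackbone


class NoOpBackbone(CLBackbone):
    """Hypothetical placeholder backbone, used only to show the task loop."""

    def __init__(self) -> None:
        CLBackbone.__init__(self, output_dim=None)


backbone = NoOpBackbone()
num_tasks = 5
for task_id in range(1, num_tasks + 1):
    backbone.setup_task_id(task_id)  # must be called before forward() for this task
    # ... train / validate on the dataset of task `task_id` ...
    assert backbone.task_id == task_id
```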
    def get_layer_by_name(self, layer_name: str) -> nn.Module:
        r"""Get the layer by its name.

        **Args:**
        - **layer_name** (`str`): the name of the layer. Note that the name is the name substituting the '.' with '/', like `block/conv1`, rather than `block.conv1`.

        **Returns:**
        - **layer** (`nn.Module`): the layer.
        """
        for name, layer in self.named_modules():
            if name == layer_name.replace("/", "."):
                return layer
Get the layer by its name.

**Args:**

- **layer_name** (`str`): the name of the layer. Note that the name uses '/' in place of '.', like `block/conv1` rather than `block.conv1`.

**Returns:**

- **layer** (`nn.Module`): the layer.
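A short sketch of the lookup, assuming a hypothetical subclass whose only weighted layer lives inside an `nn.Sequential` (so its internal name is `block.0` and its documented name is `block/0`):

```python
from torch import nn

from clarena.backbones import CLBackbone


class ToyBackbone(CLBackbone):
    """Hypothetical backbone for the lookup example."""

    def __init__(self) -> None:
        CLBackbone.__init__(self, output_dim=16)
        self.block = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.weighted_layer_names = ["block/0"]  # the nn.Linear inside self.block


backbone = ToyBackbone()
layer = backbone.get_layer_by_name("block/0")  # matched against the internal name "block.0"
assert isinstance(layer, nn.Linear)
```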
    def preceding_layer_name(self, layer_name: str) -> str:
        r"""Get the name of the preceding layer of the given layer from the stored `self.masked_layer_order`. If the given layer is the first layer, return `None`.

        **Args:**
        - **layer_name** (`str`): the name of the layer.

        **Returns:**
        - **preceding_layer_name** (`str`): the name of the preceding layer.

        **Raises:**
        - **ValueError**: if `layer_name` is not in the weighted layer order.
        """

        if layer_name not in self.weighted_layer_names:
            raise ValueError(f"The layer name {layer_name} doesn't exist.")

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == 0:
            return None
        return self.weighted_layer_names[weighted_layer_idx - 1]
Get the name of the preceding layer of the given layer from the stored `self.weighted_layer_names`. If the given layer is the first layer, return `None`.

**Args:**

- **layer_name** (`str`): the name of the layer.

**Returns:**

- **preceding_layer_name** (`str` | `None`): the name of the preceding layer, or `None` if the given layer is the first weighted layer.

**Raises:**

- **ValueError**: if `layer_name` is not in the weighted layer order.
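A sketch of how the lookup behaves for a hypothetical backbone whose `weighted_layer_names` is `["conv1", "conv2", "fc"]`:

```python
from clarena.backbones import CLBackbone


class ThreeLayerBackbone(CLBackbone):
    """Hypothetical backbone; only the weighted layer order matters here."""

    def __init__(self) -> None:
        CLBackbone.__init__(self, output_dim=10)
        self.weighted_layer_names = ["conv1", "conv2", "fc"]


backbone = ThreeLayerBackbone()
print(backbone.preceding_layer_name("conv2"))  # "conv1"
print(backbone.preceding_layer_name("conv1"))  # None, the first weighted layer
backbone.preceding_layer_name("unknown")       # raises ValueError
```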
    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
        task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): the stage of the forward pass, should be one of the following:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): the task ID where the data are from. If stage is 'train' or 'validation', it is usually from the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistence but never used, and best practices are not to provide this argument and leave it as the default value.

        **Returns:**
        - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
        """
The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**

**Args:**

- **input** (`Tensor`): the input tensor from data.
- **stage** (`str`): the stage of the forward pass, one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **task_id** (`int` | `None`): the task ID the data come from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided; the argument is then only a placeholder for API consistency and is never used, and best practice is to leave it at its default value.

**Returns:**

- **output_feature** (`Tensor`): the output feature tensor to be passed into the heads. This is the main target of backpropagation.
- **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) of each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
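To make the contract concrete, here is a minimal sketch of a `CLBackbone` subclass: a hypothetical two-layer MLP (not the library's own `MLP` class) that registers its weighted layer names and returns both the output feature and the per-layer hidden features:

```python
import torch
from torch import Tensor, nn

from clarena.backbones import CLBackbone


class TinyMLPBackbone(CLBackbone):
    """Hypothetical two-layer MLP backbone for illustration only."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        CLBackbone.__init__(self, output_dim=output_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU()
        self.weighted_layer_names = ["fc1", "fc2"]  # forward-pass order

    def forward(
        self, input: Tensor, stage: str, task_id: int | None = None
    ) -> tuple[Tensor, dict[str, Tensor]]:
        hidden_features: dict[str, Tensor] = {}
        feature = self.activation(self.fc1(input))
        hidden_features["fc1"] = feature
        feature = self.activation(self.fc2(feature))
        hidden_features["fc2"] = feature
        return feature, hidden_features


backbone = TinyMLPBackbone(input_dim=32, hidden_dim=64, output_dim=16)
backbone.setup_task_id(1)
output_feature, hidden_features = backbone(torch.randn(4, 32), stage="train")
```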
class HATMaskBackbone(CLBackbone):
    r"""The backbone network for HAT-based algorithms with learnable hard attention masks.

    HAT-based algorithms:

    - [**HAT (Hard Attention to the Task, 2018)**](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select the task-specific parameters.
    - [**Adaptive HAT (Adaptive Hard Attention to the Task, 2024)**](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) by introducing new adaptive soft gradient clipping based on parameter importance and network sparsity.
    - **CBPHAT** is what I am working on, trying combining HAT (Hard Attention to the Task) algorithm with Continual Backpropagation (CBP) by leveraging the contribution utility as the parameter importance like in AdaHAT (Adaptive Hard Attention to the Task) algorithm.
    """
The backbone network for HAT-based algorithms with learnable hard attention masks.
HAT-based algorithms:
- [**HAT (Hard Attention to the Task, 2018)**](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
- [**Adaptive HAT (Adaptive Hard Attention to the Task, 2024)**](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
- **CBPHAT** is work in progress that combines HAT with Continual Backpropagation (CBP), using the contribution utility as the parameter importance in the spirit of AdaHAT.
    def __init__(self, output_dim: int | None, gate: str) -> None:
        r"""Initialise the HAT mask backbone network with task embeddings and masks.

        **Args:**
        - **output_dim** (`int`): The output dimension which connects to CL output heads. The `input_dim` of output heads are expected to be the same as this `output_dim`. In some cases, this class is used for a block in the backbone network, which doesn't have the output dimension. In this case, it can be `None`.
        - **gate** (`str`): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
            - `sigmoid`: the sigmoid function.
        """
        CLBackbone.__init__(self, output_dim=output_dim)

        self.register_hat_mask_module_explicitly(
            gate=gate
        )  # we moved the registration of the modules to a separate method to solve a problem of multiple inheritance in terms of `nn.Module`

        HATMaskBackbone.sanity_check(self)
Initialise the HAT mask backbone network with task embeddings and masks.

**Args:**

- **output_dim** (`int` | `None`): the output dimension which connects to CL output heads. The `input_dim` of the output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that has no output dimension; in that case, it can be `None`.
- **gate** (`str`): the type of gate function turning the real-valued task embeddings into attention masks, one of the following:
    - `sigmoid`: the sigmoid function.
    def register_hat_mask_module_explicitly(self, gate: str) -> None:
        r"""Register all `nn.Module`s explicitly in this method. For `HATMaskBackbone`, they are task embedding for the current task and the masks.

        **Args:**
        - **gate** (`str`): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
            - `sigmoid`: the sigmoid function.
        """
        self.gate: str = gate
        r"""Store the type of gate function."""
        if gate == "sigmoid":
            self.gate_fn: nn.Module = nn.Sigmoid()
            r"""The gate function turning the real value task embeddings into attention masks."""

        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
        r"""Store the task embedding for the current task. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).

        We use `ModuleDict` rather than `dict` to make sure `LightningModule` can properly register these model parameters for the purpose of, like automatically transfering to device, being recorded in model summaries.

        we use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)

        **This must be defined to cover each weighted layer (just as `self.weighted_layer_names` listed) in the backbone network.** Otherwise, the uncovered parts will keep updating for all tasks and become a source of catastrophic forgetting. """
Register all `nn.Module`s explicitly in this method. For `HATMaskBackbone`, these are the task embedding for the current task and the masks.

**Args:**

- **gate** (`str`): the type of gate function turning the real-valued task embeddings into attention masks, one of the following:
    - `sigmoid`: the sigmoid function.
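The sketch below shows how a subclass might register one task embedding per weighted layer, following the documented convention that `task_embedding_t` is keyed by the weighted layer names and each embedding has size (1, number of units). The subclass and its layer sizes are hypothetical, not the library's `HATMaskMLP`:

```python
from torch import nn

from clarena.backbones import HATMaskBackbone


class TinyHATMaskMLP(HATMaskBackbone):
    """Hypothetical HAT-masked MLP backbone for illustration only."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, gate: str = "sigmoid"
    ) -> None:
        HATMaskBackbone.__init__(self, output_dim=output_dim, gate=gate)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.weighted_layer_names = ["fc1", "fc2"]

        # one task embedding per weighted layer, sized (1, number of units in that layer)
        self.task_embedding_t["fc1"] = nn.Embedding(1, hidden_dim)
        self.task_embedding_t["fc2"] = nn.Embedding(1, output_dim)
```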
    def initialise_task_embedding(self, mode: str) -> None:
        r"""Initialise the task embedding for the current task.

        **Args:**
        - **mode** (`str`): the initialisation mode for task embeddings, should be one of the following:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit task embedding from last task.
        """
        for te in self.task_embedding_t.values():
            if mode == "N01":
                nn.init.normal_(te.weight, 0, 1)
            elif mode == "U-11":
                nn.init.uniform_(te.weight, -1, 1)
            elif mode == "U01":
                nn.init.uniform_(te.weight, 0, 1)
            elif mode == "U-10":
                nn.init.uniform_(te.weight, -1, 0)
            elif mode == "last":
                pass
Initialise the task embedding for the current task.

**Args:**

- **mode** (`str`): the initialisation mode for task embeddings, one of the following:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
    def sanity_check(self) -> None:
        r"""Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: when the `gate` is not one of the valid options.
        """

        if self.gate not in ["sigmoid"]:
            raise ValueError("The gate should be one of 'sigmoid'.")
Check the sanity of the arguments.

**Raises:**

- **ValueError**: when the `gate` is not one of the valid options.
    def get_mask(
        self,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_mask: dict[str, Tensor] | None = None,
    ) -> dict[str, Tensor]:
        r"""Get the hard attention mask used in `forward()` method for different stages.

        **Args:**
        - **stage** (`str`): the stage when applying the conversion, should be one of the following:
            1. 'train': training stage. If stage is 'train', get the mask from task embedding of current task through the gate function, which is scaled by an annealed scalar. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'validation': validation stage. If stage is 'validation', get the mask from task embedding of current task through the gate function, which is scaled by `s_max`. (Note that in this stage, the binary mask hasn't been stored yet as the training is not over.)
            3. 'test': testing stage. If stage is 'test', apply the mask gate function is scaled by `s_max`, the large scaling making masks nearly binary.
        - **s_max** (`float`): the maximum scaling factor in the gate function. Doesn't apply to testing stage. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
        - **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for test. Applies only to testing stage. For other stages, it is default `None`.

        **Returns:**
        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).

        **Raises:**
        - **ValueError**: if the `batch_idx` and `batch_num` are not provided in 'train' stage; if the `s_max` is not provided in 'validation' stage; if the `task_id` is not provided in 'test' stage.
        """

        # sanity check
        if stage == "train" and (
            s_max is None or batch_idx is None or num_batches is None
        ):
            raise ValueError(
                "The `s_max`, `batch_idx` and `batch_num` should be provided at training stage, instead of the default value `None`."
            )
        if stage == "validation" and (s_max is None):
            raise ValueError(
                "The `s_max` should be provided at validation stage, instead of the default value `None`."
            )
        if stage == "test" and (test_mask is None):
            raise ValueError(
                "The `task_mask` should be provided at testing stage, instead of the default value `None`."
            )

        mask = {}
        if stage == "train":
            for layer_name in self.weighted_layer_names:
                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
                    num_batches - 1
                )  # see equation (3) in chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
                mask[layer_name] = self.gate_fn(
                    self.task_embedding_t[layer_name].weight * anneal_scalar
                ).squeeze()
        elif stage == "validation":
            for layer_name in self.weighted_layer_names:
                mask[layer_name] = self.gate_fn(
                    self.task_embedding_t[layer_name].weight * s_max
                ).squeeze()
        elif stage == "test":
            mask = test_mask

        return mask
Get the hard attention mask used in the `forward()` method for different stages.

**Args:**

- **stage** (`str`): the stage when applying the conversion, one of the following:
    1. 'train': training stage. The mask is obtained from the task embedding of the current task through the gate function, scaled by an annealed scalar. See chapter 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
    2. 'validation': validation stage. The mask is obtained from the task embedding of the current task through the gate function, scaled by `s_max`. (Note that in this stage the binary mask hasn't been stored yet, as training is not over.)
    3. 'test': testing stage. The provided binary `test_mask` (obtained with the large scaling `s_max`, which makes the masks nearly binary) is used directly.
- **s_max** (`float`): the maximum scaling factor in the gate function. Doesn't apply to the testing stage. See chapter 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
- **batch_idx** (`int` | `None`): the current batch index. Applies only to the training stage; for other stages it defaults to `None`.
- **num_batches** (`int` | `None`): the total number of batches. Applies only to the training stage; for other stages it defaults to `None`.
- **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for testing. Applies only to the testing stage; for other stages it defaults to `None`.

**Returns:**

- **mask** (`dict[str, Tensor]`): the hard attention mask, with values in [0, 1] that become nearly binary under large scaling. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).

**Raises:**

- **ValueError**: if `s_max`, `batch_idx` or `num_batches` is not provided in the 'train' stage; if `s_max` is not provided in the 'validation' stage; if `test_mask` is not provided in the 'test' stage.
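The training-stage computation described above can be illustrated standalone. The sketch below mirrors the annealing formula used in the source; the embedding size and `s_max` value are arbitrary:

```python
import torch


def annealed_scalar(s_max: float, batch_idx: int, num_batches: int) -> float:
    # s = 1/s_max + (s_max - 1/s_max) * (batch_idx - 1) / (num_batches - 1)
    return 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)


task_embedding = torch.randn(1, 64)  # one embedding value per unit of a layer
s = annealed_scalar(s_max=400.0, batch_idx=10, num_batches=100)
soft_mask = torch.sigmoid(task_embedding * s).squeeze()           # shape (64,), values in (0, 1)
near_binary_mask = torch.sigmoid(task_embedding * 400.0).squeeze()  # nearly binary at s_max
```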
    def get_cumulative_mask(self) -> dict[str, Tensor]:
        r"""Get the cumulative mask till current task.

        **Returns:**
        - **cumulative_mask** (`dict[str, Tensor]`): the cumulative mask. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
        """
        return self.cumulative_mask_for_previous_tasks
Get the cumulative mask up to the current task.

**Returns:**

- **cumulative_mask** (`dict[str, Tensor]`): the cumulative mask. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
    def get_summative_mask(self) -> dict[str, Tensor]:
        r"""Get the summative mask till current task.

        **Returns:**
        - **summative_mask** (`dict[str, Tensor]`): the summative mask tensor. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
        """
        return self.summative_mask_for_previous_tasks
Get the summative mask up to the current task.

**Returns:**

- **summative_mask** (`dict[str, Tensor]`): the summative mask. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
    def get_layer_measure_parameter_wise(
        self,
        unit_wise_measure: dict[str, Tensor],
        layer_name: str,
        aggregation: str,
    ) -> Tensor:
        r"""Get the parameter-wise measure on the parameters right before the given layer.

        It is calculated from the given unit-wise measure. It aggregates two feature-sized vectors (corresponding the given layer and preceding layer) into a weight-wise matrix (corresponding the weights in between) and bias-wise vector (corresponding the bias of the given layer), using the given aggregation method. For example, given two feature-sized measure $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is then $\min \left(a_{l,i}, a_{l-1,j}\right)$, a matrix with respect to $i, j$.

        Note that if the given layer is the first layer with no preceding layer, we will get parameter-wise measure directly broadcasted from the unit-wise measure of given layer.

        This method is used in the calculation of parameter-wise measure in various HAT-based algorithms:

        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the unit-wise cumulative mask of previous tasks `self.cumulative_mask_for_previous_tasks`, which is $\min \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in equation (2) in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the unit-wise summative mask of previous tasks `self.summative_mask_for_previous_tasks`, which is $\min \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **CBPHAT**: the parameter-wise measure is the parameter importance for previous tasks from the unit-wise importance of previous tasks `self.unit_importance_for_previous_tasks` based on contribution utility, which is $\min \left(I_{l,i}^{(t-1)}, I_{l-1,j}^{(t-1)}\right)$ in the adjustment rate formula in the paper draft.

        **Args:**
        - **unit_wise_measure** (`dict[str, Tensor]`): the unit-wise measure. Key is layer name, value is the unit-wise measure tensor. The measure tensor has size (number of units).
        - **layer_name** (`str`): the name of given layer.
        - **aggregation** (`str`): the aggregation method turning two feature-wise measures into weight-wise matrix, should be one of the following:
            - 'min': takes minimum of the two connected unit measures.
            - 'max': takes maximum of the two connected unit measures.

        **Returns:**
        - **weight_measure** (`Tensor`): the weight measure matrix, same size as the corresponding weights.
        - **bias_measure** (`Tensor`): the bias measure vector, same size as the corresponding bias.
        """

        # initialise the aggregation function
        if aggregation == "min":
            aggregation_func = torch.min
        elif aggregation == "max":
            aggregation_func = torch.max
        else:
            raise ValueError(f"The aggregation method {aggregation} is not supported.")

        # get the preceding layer name
        preceding_layer_name = self.preceding_layer_name(layer_name)

        # get weight size for expanding the measures
        layer = self.get_layer_by_name(layer_name)
        weight_size = layer.weight.size()

        # construct the weight-wise measure
        layer_measure = unit_wise_measure[layer_name]
        layer_measure_broadcast_size = (-1, 1) + tuple(
            1 for _ in range(len(weight_size) - 2)
        )  # since the size of mask tensor is (number of units), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers

        layer_measure_broadcasted = layer_measure.view(
            *layer_measure_broadcast_size
        ).expand(
            weight_size,
        )  # expand the given layer mask to the weight size and broadcast

        if (
            preceding_layer_name
        ):  # if the layer is not the first layer, where the preceding layer exists

            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
                1 for _ in range(len(weight_size) - 2)
            )  # since the size of mask tensor is (number of units), we extend it to (1, number of units) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
            preceding_layer_measure = unit_wise_measure[preceding_layer_name]
            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
                *preceding_layer_measure_broadcast_size
            ).expand(
                weight_size
            )  # expand the preceding layer mask to the weight size and broadcast
            weight_measure = aggregation_func(
                layer_measure_broadcasted, preceding_layer_measure_broadcasted
            )  # get the minimum of the two mask vectors, from expanded
        else:  # if the layer is the first layer
            weight_measure = layer_measure_broadcasted

        # construct the bias-wise measure
        bias_measure = layer_measure

        return weight_measure, bias_measure
Get the parameter-wise measure on the parameters right before the given layer.

It is calculated from the given unit-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and its preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.

Note that if the given layer is the first layer and has no preceding layer, the parameter-wise measure is directly broadcast from the unit-wise measure of the given layer.

This method is used in the calculation of the parameter-wise measure in various HAT-based algorithms:

- **HAT**: the parameter-wise measure is the binary mask for previous tasks, from the unit-wise cumulative mask of previous tasks `self.cumulative_mask_for_previous_tasks`, which is $\min \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in equation (2) of the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
- **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks, from the unit-wise summative mask of previous tasks `self.summative_mask_for_previous_tasks`, which is $\min \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in equation (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
- **CBPHAT**: the parameter-wise measure is the parameter importance for previous tasks, from the unit-wise importance of previous tasks `self.unit_importance_for_previous_tasks` based on contribution utility, which is $\min \left(I_{l,i}^{(t-1)}, I_{l-1,j}^{(t-1)}\right)$ in the adjustment rate formula in the paper draft.

**Args:**

- **unit_wise_measure** (`dict[str, Tensor]`): the unit-wise measure. Key is the layer name, value is the unit-wise measure tensor. The measure tensor has size (number of units).
- **layer_name** (`str`): the name of the given layer.
- **aggregation** (`str`): the aggregation method turning two feature-wise measures into a weight-wise matrix, one of the following:
    - 'min': takes the minimum of the two connected unit measures.
    - 'max': takes the maximum of the two connected unit measures.

**Returns:**

- **weight_measure** (`Tensor`): the weight measure matrix, same size as the corresponding weights.
- **bias_measure** (`Tensor`): the bias measure vector, same size as the corresponding bias.
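The broadcasting logic can be illustrated without the library: for a fully connected layer whose weight has shape `(n_out, n_in)`, the given layer's measure of shape `(n_out,)` is viewed as `(n_out, 1)`, the preceding layer's measure of shape `(n_in,)` as `(1, n_in)`, and the element-wise minimum gives the weight-wise matrix. This is a sketch of the idea, not a call into the library:

```python
import torch

n_out, n_in = 4, 3
layer_measure = torch.tensor([0.0, 1.0, 1.0, 0.0])  # unit-wise measure of the given layer
preceding_measure = torch.tensor([1.0, 0.0, 1.0])   # unit-wise measure of the preceding layer

weight_size = (n_out, n_in)
weight_measure = torch.min(
    layer_measure.view(-1, 1).expand(weight_size),      # (n_out, 1) -> (n_out, n_in)
    preceding_measure.view(1, -1).expand(weight_size),  # (1, n_in) -> (n_out, n_in)
)
bias_measure = layer_measure  # the bias measure matches the given layer's units

print(weight_measure)
# tensor([[0., 0., 0.],
#         [1., 0., 1.],
#         [1., 0., 1.],
#         [0., 0., 0.]])
```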
    @override
    def forward(
        self,
        input: Tensor,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_mask: dict[str, Tensor] | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. Task-specific mask for `task_id` are applied to the units in each layer.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): the stage of the forward pass, should be one of the following:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **s_max** (`float`): the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
        - **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for test. Applies only to testing stage. For other stages, it is default `None`.

        **Returns:**
        - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
        """
        # this should be copied to all subclasses. Make sure it is called to get the mask for the current task from the task embedding in this stage
        mask = self.get_mask(
            stage,
            s_max=s_max,
            batch_idx=batch_idx,
            num_batches=num_batches,
            test_mask=test_mask,
        )
The forward pass for data from task `task_id`. The task-specific mask for `task_id` is applied to the units in each layer.

**Args:**

- **input** (`Tensor`): the input tensor from data.
- **stage** (`str`): the stage of the forward pass, one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **s_max** (`float`): the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
- **batch_idx** (`int` | `None`): the current batch index. Applies only to the training stage; for other stages it defaults to `None`.
- **num_batches** (`int` | `None`): the total number of batches. Applies only to the training stage; for other stages it defaults to `None`.
- **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for testing. Applies only to the testing stage; for other stages it defaults to `None`.

**Returns:**

- **output_feature** (`Tensor`): the output feature tensor to be passed into the heads. This is the main target of backpropagation.
- **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
- **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) of each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm itself does not need it, it is provided for API consistency with other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
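Finally, a hypothetical subclass `forward()` might apply the mask by element-wise multiplying each weighted layer's activated output with that layer's mask from `get_mask()`. This extends the registration sketch above and mirrors the documented pattern, but it is not the library's own `HATMaskMLP` implementation:

```python
import torch
from torch import Tensor, nn

from clarena.backbones import HATMaskBackbone


class TinyHATMaskMLP(HATMaskBackbone):
    """Hypothetical HAT-masked MLP for illustration only."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        HATMaskBackbone.__init__(self, output_dim=output_dim, gate="sigmoid")
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU()
        self.weighted_layer_names = ["fc1", "fc2"]
        self.task_embedding_t["fc1"] = nn.Embedding(1, hidden_dim)
        self.task_embedding_t["fc2"] = nn.Embedding(1, output_dim)

    def forward(
        self,
        input: Tensor,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_mask: dict[str, Tensor] | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
        # get the mask for the current task from the task embeddings (as documented above)
        mask = self.get_mask(
            stage, s_max=s_max, batch_idx=batch_idx,
            num_batches=num_batches, test_mask=test_mask,
        )
        hidden_features: dict[str, Tensor] = {}
        feature = self.activation(self.fc1(input)) * mask["fc1"]  # gate the units of fc1
        hidden_features["fc1"] = feature
        feature = self.activation(self.fc2(feature)) * mask["fc2"]  # gate the units of fc2
        hidden_features["fc2"] = feature
        return feature, mask, hidden_features


backbone = TinyHATMaskMLP(input_dim=32, hidden_dim=64, output_dim=16)
backbone.setup_task_id(1)
backbone.initialise_task_embedding(mode="N01")
out, mask, hidden = backbone(
    torch.randn(8, 32), stage="train", s_max=400.0, batch_idx=1, num_batches=100
)
```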