clarena.backbones
Backbone Networks
This submodule provides the backbone neural network architectures for all paradigms in CLArena.
Here are the base classes for backbone networks, which inherit from PyTorch nn.Module:
- Backbone: the base class for all backbone networks. Multi-task and single-task learning can use this class directly.
- CLBackbone: the base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks.
  - HATMaskBackbone: the base class for backbones used in the HAT (Hard Attention to the Task) CL algorithm.
  - WSNMaskBackbone: the base class for backbones used in the WSN (Winning Subnetworks) CL algorithm.
Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement backbone networks:
- Configure Backbone Network: https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/backbone-network
- Implement Custom Backbone Network: https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/backbone-network
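As a quick orientation before the API details below, here is a minimal sketch of what a custom continual learning backbone might look like when subclassing `CLBackbone`. The `ToyCLBackbone` name, layer sizes, and architecture are purely illustrative assumptions, not part of CLArena.

```python
import torch
from torch import Tensor, nn

from clarena.backbones import CLBackbone


class ToyCLBackbone(CLBackbone):
    """A hypothetical two-layer fully-connected continual learning backbone (illustration only)."""

    def __init__(self, input_dim: int = 784, hidden_dim: int = 256, output_dim: int = 64) -> None:
        super().__init__(output_dim=output_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # every weighted layer must be registered, in order from input to output
        self.weighted_layer_names = ["fc1", "fc2"]

    def forward(
        self, input: Tensor, stage: str, task_id: int | None = None
    ) -> tuple[Tensor, dict[str, Tensor]]:
        activations = {}
        x = torch.relu(self.fc1(input))
        activations["fc1"] = x
        x = torch.relu(self.fc2(x))
        activations["fc2"] = x
        return x, activations  # output feature plus per-layer activations
```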
1r""" 2 3# Backbone Networks 4 5This submodule provides the **backbone neural network architectures** for all paradigms in CLArena. 6 7Here are the base classes for backbone networks, which inherit from PyTorch `nn.Module`: 8 9- `Backbone`: the base class for all backbone networks. Multi-task and single-task learning can use this class directly. 10- `CLBackbone`: the base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks. 11 - `HATMaskBackbone`: the base class for backbones used in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) CL algorithm. 12 - `WSNMaskBackbone`: The base class for backbones used in [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) CL algorithm. 13 14Please note that this is an API documentation. Please refer to the main documentation pages for more information about how to configure and implement backbone networks: 15 16- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/backbone-network) 17- [**Implement Custom Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/backbone-network) 18 19 20""" 21 22from .base import ( 23 Backbone, 24 CLBackbone, 25 HATMaskBackbone, 26 WSNMaskBackbone, 27) 28from .mlp import MLP, CLMLP, HATMaskMLP, WSNMaskMLP 29from .resnet import ( 30 ResNet18, 31 ResNet34, 32 ResNet50, 33 ResNet101, 34 ResNet152, 35 CLResNet18, 36 CLResNet34, 37 CLResNet50, 38 CLResNet101, 39 CLResNet152, 40 HATMaskResNet18, 41 HATMaskResNet34, 42 HATMaskResNet50, 43 HATMaskResNet101, 44 HATMaskResNet152, 45) 46 47 48__all__ = [ 49 "Backbone", 50 "CLBackbone", 51 "HATMaskBackbone", 52 "WSNMaskBackbone", 53 "mlp", 54 "resnet", 55]
```python
class Backbone(nn.Module):
    r"""The base class for backbone networks."""

    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__()

        self.output_dim: int = output_dim
        r"""The output dimension of the backbone network."""

        self.weighted_layer_names: list[str] = []
        r"""The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers. They are the main part of neural networks. **It must be provided in subclasses.**

        The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the `nn.Module` internal naming mechanism with `.` replaced with `/`. For example:
        - If a layer is assigned to `self.conv1`, the name becomes `conv1`.
        - If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc.
        - If a hierarchical structure is used, for example, a `nn.Module` is assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it should have been `block.conv1` according to `nn.Module`'s rules, but we use '/' instead of '.' to avoid errors when using '.' as keys in a `ModuleDict`.
        """

    def get_layer_by_name(self, layer_name: str | None) -> nn.Module | None:
        r"""Get the layer by its name.

        **Args:**
        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.

        **Returns:**
        - **layer** (`nn.Module` | `None`): The layer. If `layer_name` is `None`, return `None`.
        """
        if layer_name is None:
            return None

        for name, layer in self.named_modules():
            if name == layer_name.replace("/", "."):
                return layer

    def preceding_layer_name(self, layer_name: str | None) -> str | None:
        r"""Get the name of the preceding layer of the given layer from the stored `weighted_layer_names`.

        **Args:**
        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.

        **Returns:**
        - **preceding_layer_name** (`str`): The name of the preceding layer. If the given layer is the first layer, return `None`.
        """
        if layer_name is None:
            return None

        if layer_name not in self.weighted_layer_names:
            raise ValueError(
                f"The layer name {layer_name} doesn't exist in weighted layer names."
            )

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == 0:
            return None
        preceding_layer_name = self.weighted_layer_names[weighted_layer_idx - 1]
        return preceding_layer_name

    def next_layer_name(self, layer_name: str) -> str | None:
        r"""Get the name of the next layer of the given layer from the stored `weighted_layer_names`. If the given layer is the last layer of the backbone, return `None`.

        **Args:**
        - **layer_name** (`str`): The name of the layer.

        **Returns:**
        - **next_layer_name** (`str`): The name of the next layer.

        **Raises:**
        - **ValueError**: If `layer_name` is not in the weighted layer order.
        """
        if layer_name not in self.weighted_layer_names:
            raise ValueError(f"The layer name {layer_name} doesn't exist.")

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == len(self.weighted_layer_names) - 1:
            return None
        next_layer_name = self.weighted_layer_names[weighted_layer_idx + 1]
        return next_layer_name

    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for certain algorithms that need to use the hidden features for various purposes.
        """
```
The base class for backbone networks.
def __init__(self, output_dim: int | None, **kwargs) -> None:
Args:
- output_dim (`int` | `None`): The output dimension that connects to output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
- kwargs: Reserved for multiple inheritance.
The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers. They are the main part of neural networks. It must be provided in subclasses.
The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the `nn.Module` internal naming mechanism with `.` replaced with `/`. For example:
- If a layer is assigned to `self.conv1`, the name becomes `conv1`.
- If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc.
- If a hierarchical structure is used, for example, a `nn.Module` is assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it should have been `block.conv1` according to `nn.Module`'s rules, but we use '/' instead of '.' to avoid errors when using '.' as keys in a `ModuleDict`.
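For example, the naming convention plays out as follows in a hypothetical backbone (the `TinyBackbone` class, its layers, and sizes are made up for illustration):

```python
from torch import nn

from clarena.backbones import Backbone


class TinyBackbone(Backbone):
    """A hypothetical backbone used only to illustrate weighted layer naming."""

    def __init__(self) -> None:
        super().__init__(output_dim=32)
        self.block = nn.Sequential(nn.Linear(64, 32), nn.ReLU())  # nested module
        self.head_in = nn.Linear(32, 32)                          # plain attribute
        # the Linear inside the Sequential is named 'block/0' ('.' replaced with '/');
        # the plain attribute keeps its attribute name
        self.weighted_layer_names = ["block/0", "head_in"]
```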
def get_layer_by_name(self, layer_name: str | None) -> nn.Module | None:
Get the layer by its name.
Args:
- layer_name (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.
Returns:
- layer (`nn.Module` | `None`): The layer. If `layer_name` is `None`, return `None`.
def preceding_layer_name(self, layer_name: str | None) -> str | None:
Get the name of the preceding layer of the given layer from the stored weighted_layer_names.
Args:
- layer_name (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.
Returns:
- preceding_layer_name (`str`): The name of the preceding layer. If the given layer is the first layer, return `None`.
def next_layer_name(self, layer_name: str) -> str | None:
Get the name of the next layer of the given layer from the stored `weighted_layer_names`. If the given layer is the last layer of the backbone, return `None`.
Args:
- layer_name (`str`): The name of the layer.
Returns:
- next_layer_name (`str`): The name of the next layer.
Raises:
- ValueError: If `layer_name` is not in the weighted layer order.
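Continuing the hypothetical `TinyBackbone` sketch above, the navigation helpers resolve layer names like this:

```python
backbone = TinyBackbone()

print(backbone.get_layer_by_name("block/0"))     # Linear(in_features=64, out_features=32, bias=True)
print(backbone.preceding_layer_name("block/0"))  # None -- first weighted layer
print(backbone.next_layer_name("block/0"))       # 'head_in'
print(backbone.preceding_layer_name("head_in"))  # 'block/0'
```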
def forward(self, input: Tensor, stage: str) -> tuple[Tensor, dict[str, Tensor]]:
The forward pass. It must be implemented by subclasses.
Args:
- input (`Tensor`): The input tensor from data.
- stage (`str`): The stage of the forward pass; one of:
  1. 'train': training stage.
  2. 'validation': validation stage.
  3. 'test': testing stage.
Returns:
- output_feature (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
- activations (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for certain algorithms that need to use the hidden features for various purposes.
```python
class CLBackbone(Backbone):
    r"""The base class of continual learning backbone networks."""

    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up task `task_id`. This must be done before the `forward()` method is called."""
        self.task_id = task_id
        self.processed_task_ids.append(task_id)

    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
        task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
        """
```
The base class of continual learning backbone networks.
def __init__(self, output_dim: int | None, **kwargs) -> None:
Args:
- output_dim (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
- kwargs: Reserved for multiple inheritance.
Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to the number of tasks in the CL dataset.
def forward(self, input: Tensor, stage: str, task_id: int | None = None) -> tuple[Tensor, dict[str, Tensor]]:
The forward pass for data from task task_id. In some backbones, the forward pass might be different for different tasks. It must be implemented by subclasses.
Args:
- input (`Tensor`): The input tensor from data.
- stage (`str`): The stage of the forward pass; one of:
  1. 'train': training stage.
  2. 'validation': validation stage.
  3. 'test': testing stage.
- task_id (`int` | `None`): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.
Returns:
- output_feature (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
- activations (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
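Putting the task-ID mechanics together, a typical task loop might look like the following sketch. It reuses the hypothetical `ToyCLBackbone` from the top of this page and a random tensor in place of real dataloader batches.

```python
import torch

backbone = ToyCLBackbone()            # hypothetical subclass sketched earlier
dummy_batch = torch.randn(8, 784)     # stand-in for a real batch

for task_id in [1, 2, 3]:             # task IDs count from 1
    backbone.setup_task_id(task_id)   # must be called before forward()

    # training / validation forward passes implicitly use the current task
    feature, activations = backbone(dummy_batch, stage="train")

    # TIL-style testing: the task ID of the test data may be passed explicitly
    feature, activations = backbone(dummy_batch, stage="test", task_id=task_id)
```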
```python
class HATMaskBackbone(CLBackbone):
    r"""The backbone network for HAT-based algorithms with learnable hard attention masks.

    HAT-based algorithms include:

    - [**HAT (Hard Attention to the Task, 2018)**](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
    - [**AdaHAT (Adaptive Hard Attention to the Task, 2024)**](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
    - **FG-AdaHAT** is an architecture-based continual learning approach that improves HAT by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
    """

    def __init__(self, output_dim: int | None, gate: str, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **gate** (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
            - `sigmoid`: the sigmoid function.
            - `tanh`: the hyperbolic tangent function.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        self.gate: str = gate
        r"""The type of gate function."""
        self.gate_fn: Callable
        r"""The gate function mapping the real value task embeddings into attention masks."""

        if gate == "sigmoid":
            self.gate_fn = nn.Sigmoid()

        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
        r"""The task embedding for the current task `self.task_id`. Keys are layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).

        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.

        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)

        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
        """

        self.masks: dict[int, dict[str, Tensor]] = {}
        r"""The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, )."""

        HATMaskBackbone.sanity_check(self)

    def initialize_task_embedding(self, mode: str) -> None:
        r"""Initialize the task embedding for the current task `self.task_id`.

        **Args:**
        - **mode** (`str`): The initialization mode for task embeddings; one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit task embeddings from the last task.
        """
        for te in self.task_embedding_t.values():
            if mode == "N01":
                nn.init.normal_(te.weight, 0, 1)
            elif mode == "U-11":
                nn.init.uniform_(te.weight, -1, 1)
            elif mode == "U01":
                nn.init.uniform_(te.weight, 0, 1)
            elif mode == "U-10":
                nn.init.uniform_(te.weight, -1, 0)
            elif mode == "last":
                pass

    def sanity_check(self) -> None:
        r"""Sanity check."""

        if self.gate not in ["sigmoid"]:
            raise ValueError("The gate should be one of: 'sigmoid'.")

    def get_mask(
        self,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_task_id: int | None = None,
    ) -> dict[str, Tensor]:
        r"""Get the hard attention mask used in the `forward()` method for different stages.

        **Args:**
        - **stage** (`str`): The stage when applying the conversion; one of:
            1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by `s_max`, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
            3. 'test': testing stage. Apply the test mask directly from the stored masks using `test_task_id`.
        - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **mask** (`dict[str, Tensor]`): The hard attention (with values 0 or 1) mask. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        """

        # sanity check
        if stage == "train" and (
            s_max is None or batch_idx is None or num_batches is None
        ):
            raise ValueError(
                "The `s_max`, `batch_idx` and `num_batches` should be provided at training stage, instead of the default value `None`."
            )
        if stage == "validation" and (s_max is None):
            raise ValueError(
                "The `s_max` should be provided at validation stage, instead of the default value `None`."
            )
        if stage == "test" and (test_task_id is None):
            raise ValueError(
                "The `test_task_id` should be provided at testing stage, instead of the default value `None`."
            )

        mask = {}
        if stage == "train":
            for layer_name in self.weighted_layer_names:
                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
                    num_batches - 1
                )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
                mask[layer_name] = self.gate_fn(
                    self.task_embedding_t[layer_name].weight * anneal_scalar
                ).squeeze()
        elif stage == "validation":
            for layer_name in self.weighted_layer_names:
                mask[layer_name] = self.gate_fn(
                    self.task_embedding_t[layer_name].weight * s_max
                ).squeeze()
        elif stage == "test":
            mask = self.masks[test_task_id]

        return mask

    def te_to_binary_mask(self) -> dict[str, Tensor]:
        r"""Convert the current task embedding into a binary mask.

        This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.

        **Returns:**
        - **mask_t** (`dict[str, Tensor]`): The binary mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        """
        # get the mask for the current task
        mask_t = {
            layer_name: (self.task_embedding_t[layer_name].weight > 0)
            .float()
            .squeeze()
            .detach()
            for layer_name in self.weighted_layer_names
        }

        return mask_t

    def store_mask(self) -> dict[str, Tensor]:
        r"""Store the mask for the current task `self.task_id`."""
        mask_t = self.te_to_binary_mask()
        self.masks[self.task_id] = mask_t

        return mask_t

    def get_layer_measure_parameter_wise(
        self,
        neuron_wise_measure: dict[str, Tensor],
        layer_name: str,
        aggregation_mode: str,
    ) -> tuple[Tensor, Tensor]:
        r"""Get the parameter-wise measure on the parameters right before the given layer.

        It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(a_{l,i}, a_{l-1,j}\right)$, a matrix with respect to $i, j$.

        Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.

        This method is used to calculate parameter-wise measures in various HAT-based algorithms:

        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks `cumulative_mask_for_previous_tasks`, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks `summative_mask_for_previous_tasks`, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **FG-AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks `summative_importance_for_previous_tasks`, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.

        **Args:**
        - **neuron_wise_measure** (`dict[str, Tensor]`): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
        - **layer_name** (`str`): The name of the given layer.
        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
            - 'min': takes the minimum of the two connected unit measures.
            - 'max': takes the maximum of the two connected unit measures.
            - 'mean': takes the mean of the two connected unit measures.

        **Returns:**
        - **weight_measure** (`Tensor`): The weight measure matrix, the same size as the corresponding weights.
        - **bias_measure** (`Tensor`): The bias measure vector, the same size as the corresponding bias.
        """

        # initialize the aggregation function
        if aggregation_mode == "min":
            aggregation_func = torch.min
        elif aggregation_mode == "max":
            aggregation_func = torch.max
        elif aggregation_mode == "mean":
            aggregation_func = torch.mean
        else:
            raise ValueError(
                f"The aggregation method {aggregation_mode} is not supported."
            )

        # get the preceding layer
        preceding_layer_name = self.preceding_layer_name(layer_name)

        # get weight size for expanding the measures
        layer = self.get_layer_by_name(layer_name)
        weight_size = layer.weight.size()

        # construct the weight-wise measure
        layer_measure = neuron_wise_measure[layer_name]
        layer_measure_broadcast_size = (-1, 1) + tuple(
            1 for _ in range(len(weight_size) - 2)
        )  # since the size of the measure tensor is (number of units, ), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers

        layer_measure_broadcasted = layer_measure.view(
            *layer_measure_broadcast_size
        ).expand(
            weight_size,
        )  # expand the given layer measure to the weight size and broadcast

        if (
            preceding_layer_name
        ):  # if the layer is not the first layer, where the preceding layer exists

            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
                1 for _ in range(len(weight_size) - 2)
            )  # since the size of the measure tensor is (number of units, ), we extend it to (1, number of units) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers
            preceding_layer_measure = neuron_wise_measure[preceding_layer_name]
            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
                *preceding_layer_measure_broadcast_size
            ).expand(
                weight_size
            )  # expand the preceding layer measure to the weight size and broadcast
            weight_measure = aggregation_func(
                layer_measure_broadcasted, preceding_layer_measure_broadcasted
            )  # aggregate the two broadcast measures
        else:  # if the layer is the first layer
            weight_measure = layer_measure_broadcasted

        # construct the bias-wise measure
        bias_measure = layer_measure

        return weight_measure, bias_measure

    @override
    def forward(
        self,
        input: Tensor,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the units in each layer.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **s_max** (`float`): The maximum scaling factor in the gate function. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
        """
```
The backbone network for HAT-based algorithms with learnable hard attention masks.
HAT-based algorithms include:
- HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
- AdaHAT (Adaptive Hard Attention to the Task, 2024) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
- FG-AdaHAT is an architecture-based continual learning approach that improves HAT by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
def __init__(self, output_dim: int | None, gate: str, **kwargs) -> None:
Args:
- output_dim (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
- gate (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
  - `sigmoid`: the sigmoid function.
  - `tanh`: the hyperbolic tangent function.
- kwargs: Reserved for multiple inheritance.
The task embedding for the current task self.task_id. Keys are layer names and values are the task embedding nn.Embedding for the layer. Each task embedding has size (1, number of units).
We use ModuleDict rather than dict to ensure LightningModule properly registers these model parameters for purposes such as automatic device transfer and model summaries.
We use nn.Embedding rather than nn.Parameter to store the task embedding for each layer, which is a type of nn.Module and can be accepted by nn.ModuleDict. (nn.Parameter cannot be accepted by nn.ModuleDict.)
This must be defined to cover each weighted layer (as listed in weighted_layer_names) in the backbone network. Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ).
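To make the coverage requirement concrete, here is a minimal sketch of a hypothetical HAT-masked MLP. The class name, layer sizes, and the way the mask is applied in `forward()` are illustrative assumptions, not the library's `HATMaskMLP` implementation.

```python
import torch
from torch import Tensor, nn

from clarena.backbones import HATMaskBackbone


class ToyHATMaskMLP(HATMaskBackbone):
    """A hypothetical HAT-masked two-layer MLP (illustration only)."""

    def __init__(self, input_dim: int = 784, hidden_dim: int = 256, output_dim: int = 64) -> None:
        super().__init__(output_dim=output_dim, gate="sigmoid")
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.weighted_layer_names = ["fc1", "fc2"]
        # one task embedding of size (1, number of units) per weighted layer,
        # so that every weighted layer is covered by a mask
        self.task_embedding_t["fc1"] = nn.Embedding(1, hidden_dim)
        self.task_embedding_t["fc2"] = nn.Embedding(1, output_dim)

    def forward(
        self,
        input: Tensor,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
        mask = self.get_mask(
            stage, s_max=s_max, batch_idx=batch_idx,
            num_batches=num_batches, test_task_id=test_task_id,
        )
        activations = {}
        x = torch.relu(self.fc1(input)) * mask["fc1"]  # unit-wise mask on fc1 outputs
        activations["fc1"] = x
        x = torch.relu(self.fc2(x)) * mask["fc2"]      # unit-wise mask on fc2 outputs
        activations["fc2"] = x
        return x, mask, activations
```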
def initialize_task_embedding(self, mode: str) -> None:
Initialize the task embedding for the current task self.task_id.
Args:
- mode (`str`): The initialization mode for task embeddings; one of:
  - 'N01' (default): standard normal distribution $N(0, 1)$.
  - 'U-11': uniform distribution $U(-1, 1)$.
  - 'U01': uniform distribution $U(0, 1)$.
  - 'U-10': uniform distribution $U(-1, 0)$.
  - 'last': inherit task embeddings from the last task.
def sanity_check(self) -> None:
Sanity check.
def get_mask(self, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_task_id: int | None = None) -> dict[str, Tensor]:
Get the hard attention mask used in the forward() method for different stages.
Args:
- stage (`str`): The stage when applying the conversion; one of:
  - 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the HAT paper.
  - 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by `s_max`, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
  - 'test': testing stage. Apply the test mask directly from the stored masks using `test_task_id`.
- s_max (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the HAT paper.
- batch_idx (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
- num_batches (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
- test_task_id (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
Returns:
- mask (`dict[str, Tensor]`): The hard attention (with values 0 or 1) mask. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
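A sketch of how the mask behaves at different stages, assuming the hypothetical `ToyHATMaskMLP` above and a typical HAT value of `s_max = 400`:

```python
backbone = ToyHATMaskMLP()
backbone.setup_task_id(1)
backbone.initialize_task_embedding(mode="N01")

# training: the scaling is annealed from 1/s_max towards s_max across the epoch's batches
train_mask = backbone.get_mask(stage="train", s_max=400.0, batch_idx=10, num_batches=100)

# validation: the gate is scaled by s_max itself, so the mask is nearly binary
val_mask = backbone.get_mask(stage="validation", s_max=400.0)
print(train_mask["fc1"].shape)   # torch.Size([256])
```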
def te_to_binary_mask(self) -> dict[str, Tensor]:
Convert the current task embedding into a binary mask.
This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.
Returns:
- mask_t (`dict[str, Tensor]`): The binary mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
def store_mask(self) -> dict[str, Tensor]:
Store the mask for the current task self.task_id.
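In a typical HAT task loop, the binary mask is stored once training of the current task finishes, so it can be retrieved later at test time; a short sketch continuing the `ToyHATMaskMLP` example:

```python
# ... after training on the current task has finished ...
binary_mask_t = backbone.store_mask()   # te_to_binary_mask() + saving into backbone.masks

# later, when testing data from that task:
test_mask = backbone.get_mask(stage="test", test_task_id=backbone.task_id)
assert test_mask is backbone.masks[backbone.task_id]
```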
def get_layer_measure_parameter_wise(self, neuron_wise_measure: dict[str, Tensor], layer_name: str, aggregation_mode: str) -> tuple[Tensor, Tensor]:
Get the parameter-wise measure on the parameters right before the given layer.
It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(a_{l,i}, a_{l-1,j}\right)$, a matrix with respect to $i, j$.
Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.
This method is used to calculate parameter-wise measures in various HAT-based algorithms:
- HAT: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks `cumulative_mask_for_previous_tasks`, which is $\text{Agg}\left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the HAT paper.
- AdaHAT: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks `summative_mask_for_previous_tasks`, which is $\text{Agg}\left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the AdaHAT paper.
- FG-AdaHAT: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks `summative_importance_for_previous_tasks`, which is $\text{Agg}\left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.
Args:
- neuron_wise_measure (`dict[str, Tensor]`): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
- layer_name (`str`): The name of the given layer.
- aggregation_mode (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
  - 'min': takes the minimum of the two connected unit measures.
  - 'max': takes the maximum of the two connected unit measures.
  - 'mean': takes the mean of the two connected unit measures.
Returns:
- weight_measure (`Tensor`): The weight measure matrix, the same size as the corresponding weights.
- bias_measure (`Tensor`): The bias measure vector, the same size as the corresponding bias.
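The shapes work out as follows, again assuming the hypothetical `ToyHATMaskMLP` (a 256-unit `fc1` feeding a 64-unit `fc2`) and a made-up neuron-wise measure:

```python
import torch

backbone = ToyHATMaskMLP()

# a neuron-wise measure, e.g. a cumulative binary mask of previous tasks
neuron_wise_measure = {
    "fc1": torch.ones(256),
    "fc2": torch.zeros(64),
}

weight_measure, bias_measure = backbone.get_layer_measure_parameter_wise(
    neuron_wise_measure=neuron_wise_measure,
    layer_name="fc2",
    aggregation_mode="min",
)
print(weight_measure.shape)   # torch.Size([64, 256]) -- same shape as fc2.weight
print(bias_measure.shape)     # torch.Size([64])      -- same shape as fc2.bias
```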
def forward(self, input: Tensor, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_task_id: int | None = None) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
class WSNMaskBackbone(CLBackbone):
    r"""The backbone network for the WSN algorithm with learnable parameter masks.

    [WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) is an architecture-based continual learning algorithm. It trains learnable parameter-wise scores and selects the most scored $c\%$ of the network parameters to be used for each task.
    """

    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        self.gate_fn: torch.autograd.Function = PercentileLayerParameterMaskingByScore
        r"""The gate function mapping the real-value parameter score into binary parameter masks. It is a custom autograd function that applies percentile parameter masking by score."""

        self.weight_score_t: nn.ModuleDict = nn.ModuleDict()
        r"""The weight score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (output features, input features) as the weight.

        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.

        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)

        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
        """

        self.bias_score_t: nn.ModuleDict = nn.ModuleDict()
        r"""The bias score for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (1, output features) as the bias. If the layer doesn't have a bias, it is `None`.

        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.

        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)

        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
        """

        WSNMaskBackbone.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

    def initialize_parameter_score(self, mode: str) -> None:
        r"""Initialize the parameter score for the current task.

        **Args:**
        - **mode** (`str`): The initialization mode for parameter scores; one of:
            1. 'default': the default initialization mode in the original WSN code.
            2. 'N01': standard normal distribution $N(0, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
        """

        for layer_name, weight_score in self.weight_score_t.items():
            if mode == "default":
                # Kaiming uniform initialization for the weight score
                nn.init.kaiming_uniform_(weight_score.weight, a=math.sqrt(5))

                for layer_name, bias_score in self.bias_score_t.items():
                    if bias_score is not None:
                        # For the bias, follow the standard bias initialization using fan_in
                        weight_score = self.weight_score_t[layer_name]
                        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                            weight_score.weight
                        )
                        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                        nn.init.uniform_(bias_score.weight, -bound, bound)
            elif mode == "N01":
                nn.init.normal_(weight_score.weight, 0, 1)
                for layer_name, bias_score in self.bias_score_t.items():
                    if bias_score is not None:
                        nn.init.normal_(bias_score.weight, 0, 1)
            elif mode == "U01":
                nn.init.uniform_(weight_score.weight, 0, 1)
                for layer_name, bias_score in self.bias_score_t.items():
                    if bias_score is not None:
                        nn.init.uniform_(bias_score.weight, 0, 1)

    def get_mask(
        self,
        stage: str,
        mask_percentage: float,
        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
    ) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
        r"""Get the binary parameter mask used in the `forward()` method for different stages.

        **Args:**
        - **stage** (`str`): The stage when applying the conversion; one of:
            1. 'train': training stage. Get the mask from the parameter score of the current task through the gate function that masks the top $c\%$ largest scored parameters. See Sec. 3.1 "Winning Subnetworks" in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
            2. 'validation': validation stage. Same as 'train'. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
            3. 'test': testing stage. Apply the test mask directly from the argument `test_mask`.
        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **weight_mask** (`dict[str, Tensor]`): The binary mask on weights. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
        - **bias_mask** (`dict[str, Tensor]`): The binary mask on biases. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
        """
        weight_mask = {}
        bias_mask = {}
        if stage == "train" or stage == "validation":
            for layer_name in self.weighted_layer_names:
                weight_mask[layer_name] = self.gate_fn.apply(
                    self.weight_score_t[layer_name].weight, mask_percentage
                )
                if self.bias_score_t[layer_name] is not None:
                    bias_mask[layer_name] = self.gate_fn.apply(
                        self.bias_score_t[layer_name].weight.squeeze(
                            0
                        ),  # from (1, output_dim) to (output_dim, )
                        mask_percentage,
                    )
                else:
                    bias_mask[layer_name] = None
        elif stage == "test":
            weight_mask, bias_mask = test_mask

        return weight_mask, bias_mask
@override
def forward(
    self,
    input: Tensor,
    stage: str,
    mask_percentage: float,
    test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
    r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the parameters in each layer.

    **Args:**
    - **input** (`Tensor`): The input tensor from data.
    - **stage** (`str`): The stage of the forward pass; one of:
        1. 'train': training stage.
        2. 'validation': validation stage.
        3. 'test': testing stage.
    - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
    - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.

    **Returns:**
    - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
    - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
    - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
    - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need to use the hidden features for various purposes.
    """
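For concreteness, a subclass's forward pass can apply the binary masks by element-wise multiplying each layer's weight and bias before the usual operation. Below is a minimal sketch for one fully connected layer (an illustration under that assumption, not code taken from the CLArena subclasses such as `WSNMaskMLP`):

```python
import torch
from torch import Tensor
from torch.nn import functional as F


def masked_linear_forward(
    input: Tensor,
    weight: Tensor,
    bias: Tensor | None,
    weight_mask: Tensor,
    bias_mask: Tensor | None,
) -> Tensor:
    """Sketch: apply WSN-style binary parameter masks to one linear layer."""
    masked_weight = weight * weight_mask  # element-wise, same (output features, input features) shape
    masked_bias = bias * bias_mask if bias is not None and bias_mask is not None else bias
    return F.linear(input, masked_weight, masked_bias)
```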