clarena.backbones
Backbone Networks
This submodule provides the backbone neural network architectures for all paradigms in CLArena.
Here are the base classes for backbone networks, which inherit from PyTorch nn.Module:
- `Backbone`: the base class for all backbone networks. Multi-task and single-task learning can use this class directly.
- `CLBackbone`: the base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks.
    - `HATMaskBackbone`: the base class for backbones used in the HAT (Hard Attention to the Task) CL algorithm.
    - `AmnesiacHATBackbone`: the base class for backbones used in the AmnesiacHAT CL algorithm.
    - `WSNMaskBackbone`: the base class for backbones used in the WSN (Winning Subnetworks) CL algorithm.
Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement backbone networks:

- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/backbone-network)
- [**Implement Custom Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/backbone-network)
```python
r"""

# Backbone Networks

This submodule provides the **backbone neural network architectures** for all paradigms in CLArena.

Here are the base classes for backbone networks, which inherit from PyTorch `nn.Module`:

- `Backbone`: the base class for all backbone networks. Multi-task and single-task learning can use this class directly.
- `CLBackbone`: the base class for continual learning backbone networks, which incorporates mechanisms for managing continual learning tasks.
    - `HATMaskBackbone`: the base class for backbones used in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) CL algorithm.
    - `AmnesiacHATBackbone`: the base class for backbones used in AmnesiacHAT CL algorithm.
    - `WSNMaskBackbone`: the base class for backbones used in [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) CL algorithm.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement backbone networks:

- [**Configure Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/backbone-network)
- [**Implement Custom Backbone Network**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/backbone-network)
"""

from .base import (
    Backbone,
    CLBackbone,
    HATMaskBackbone,
    AmnesiacHATBackbone,
    WSNMaskBackbone,
)
from .mlp import MLP, CLMLP
from .hat_mask_mlp import HATMaskMLP
from .amnesiac_hat_mlp import AmnesiacHATMLP
from .wsn_mask_mlp import WSNMaskMLP
from .resnet import (
    ResNet18,
    ResNet34,
    ResNet50,
    ResNet101,
    ResNet152,
    CLResNet18,
    CLResNet34,
    CLResNet50,
    CLResNet101,
    CLResNet152,
)
from .hat_mask_resnet import (
    HATMaskResNet18,
    HATMaskResNet34,
    HATMaskResNet50,
    HATMaskResNet101,
    HATMaskResNet152,
)
from .amnesiac_hat_resnet import (
    AmnesiacHATResNet18,
    AmnesiacHATResNet34,
    AmnesiacHATResNet50,
    AmnesiacHATResNet101,
    AmnesiacHATResNet152,
)


__all__ = [
    "Backbone",
    "CLBackbone",
    "HATMaskBackbone",
    "AmnesiacHATBackbone",
    "WSNMaskBackbone",
    "mlp",
    "hat_mask_mlp",
    "amnesiac_hat_mlp",
    "wsn_mask_mlp",
    "resnet",
    "hat_mask_resnet",
    "amnesiac_hat_resnet",
]
```
```python
class Backbone(nn.Module):
    r"""The base class for backbone networks."""

    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__()

        self.output_dim: int | None = output_dim
        r"""The output dimension of the backbone network."""

        self.weighted_layer_names: list[str] = []
        r"""The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers. They are the main part of neural networks. **It must be provided in subclasses.**

        The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the `nn.Module` internal naming mechanism with `.` replaced with `/`. For example:
        - If a layer is assigned to `self.conv1`, the name becomes `conv1`.
        - If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc.
        - If a hierarchical structure is used, for example, a `nn.Module` is assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it should have been `block.conv1` according to `nn.Module`'s rules, but we use '/' instead of '.' to avoid errors when using '.' as keys in a `ModuleDict`.
        """

    def get_layer_by_name(self, layer_name: str | None) -> nn.Module | None:
        r"""Get the layer by its name.

        **Args:**
        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.

        **Returns:**
        - **layer** (`nn.Module` | `None`): The layer. If `layer_name` is `None` or not found, return `None`.
        """
        if layer_name is None:
            return None

        for name, layer in self.named_modules():
            if name == layer_name.replace("/", "."):
                return layer
        return None

    def preceding_layer_name(self, layer_name: str | None) -> str | None:
        r"""Get the name of the preceding layer of the given layer from the stored `weighted_layer_names`.

        **Args:**
        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.

        **Returns:**
        - **preceding_layer_name** (`str` | `None`): The name of the preceding layer. If the given layer is the first layer, return `None`.
        """
        if layer_name is None:
            return None

        if layer_name not in self.weighted_layer_names:
            raise ValueError(
                f"The layer name {layer_name} doesn't exist in weighted layer names."
            )

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == 0:
            return None
        preceding_layer_name = self.weighted_layer_names[weighted_layer_idx - 1]
        return preceding_layer_name

    def next_layer_name(self, layer_name: str) -> str | None:
        r"""Get the name of the next layer of the given layer from the stored `weighted_layer_names`. If the given layer is the last layer of the backbone, return `None`.

        **Args:**
        - **layer_name** (`str`): The name of the layer.

        **Returns:**
        - **next_layer_name** (`str` | `None`): The name of the next layer.

        **Raises:**
        - **ValueError**: If `layer_name` is not in the weighted layer order.
        """
        if layer_name not in self.weighted_layer_names:
            raise ValueError(f"The layer name {layer_name} doesn't exist.")

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == len(self.weighted_layer_names) - 1:
            return None
        next_layer_name = self.weighted_layer_names[weighted_layer_idx + 1]
        return next_layer_name

    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for certain algorithms that need to use the hidden features for various purposes.
        """
```
The base class for backbone networks.
```python
    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__()

        self.output_dim: int | None = output_dim
        r"""The output dimension of the backbone network."""

        self.weighted_layer_names: list[str] = []
        r"""The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers. They are the main part of neural networks. **It must be provided in subclasses.**

        The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the `nn.Module` internal naming mechanism with `.` replaced with `/`. For example:
        - If a layer is assigned to `self.conv1`, the name becomes `conv1`.
        - If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc.
        - If a hierarchical structure is used, for example, a `nn.Module` is assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it should have been `block.conv1` according to `nn.Module`'s rules, but we use '/' instead of '.' to avoid errors when using '.' as keys in a `ModuleDict`.
        """
```
**Args:**
- **output_dim** (`int` | `None`): The output dimension that connects to output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
- **kwargs**: Reserved for multiple inheritance.

The list of the weighted layer names in order (from input to output). A weighted layer has weights connecting to other weighted layers. They are the main part of neural networks. **It must be provided in subclasses.**

The layer names must match the names of weighted layers defined in the backbone and include all of them. The names follow the `nn.Module` internal naming mechanism with `.` replaced with `/`. For example:

- If a layer is assigned to `self.conv1`, the name becomes `conv1`.
- If `nn.Sequential` is used, the name becomes the index of the layer in the sequence, such as `0`, `1`, etc.
- If a hierarchical structure is used, for example, a `nn.Module` is assigned to `self.block` which has `self.conv1`, the name becomes `block/conv1`. Note that it should have been `block.conv1` according to `nn.Module`'s rules, but we use '/' instead of '.' to avoid errors when using '.' as keys in a `ModuleDict`.
```python
    def get_layer_by_name(self, layer_name: str | None) -> nn.Module | None:
        r"""Get the layer by its name.

        **Args:**
        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.

        **Returns:**
        - **layer** (`nn.Module` | `None`): The layer. If `layer_name` is `None` or not found, return `None`.
        """
        if layer_name is None:
            return None

        for name, layer in self.named_modules():
            if name == layer_name.replace("/", "."):
                return layer
        return None
```
Get the layer by its name.

**Args:**
- **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.

**Returns:**
- **layer** (`nn.Module` | `None`): The layer. If `layer_name` is `None`, return `None`.
```python
    def preceding_layer_name(self, layer_name: str | None) -> str | None:
        r"""Get the name of the preceding layer of the given layer from the stored `weighted_layer_names`.

        **Args:**
        - **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.

        **Returns:**
        - **preceding_layer_name** (`str` | `None`): The name of the preceding layer. If the given layer is the first layer, return `None`.
        """
        if layer_name is None:
            return None

        if layer_name not in self.weighted_layer_names:
            raise ValueError(
                f"The layer name {layer_name} doesn't exist in weighted layer names."
            )

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == 0:
            return None
        preceding_layer_name = self.weighted_layer_names[weighted_layer_idx - 1]
        return preceding_layer_name
```
Get the name of the preceding layer of the given layer from the stored `weighted_layer_names`.

**Args:**
- **layer_name** (`str` | `None`): The layer name following the `nn.Module` internal naming mechanism with `.` replaced with `/`. If `None`, return `None`.

**Returns:**
- **preceding_layer_name** (`str` | `None`): The name of the preceding layer. If the given layer is the first layer, return `None`.
```python
    def next_layer_name(self, layer_name: str) -> str | None:
        r"""Get the name of the next layer of the given layer from the stored `weighted_layer_names`. If the given layer is the last layer of the backbone, return `None`.

        **Args:**
        - **layer_name** (`str`): The name of the layer.

        **Returns:**
        - **next_layer_name** (`str` | `None`): The name of the next layer.

        **Raises:**
        - **ValueError**: If `layer_name` is not in the weighted layer order.
        """
        if layer_name not in self.weighted_layer_names:
            raise ValueError(f"The layer name {layer_name} doesn't exist.")

        weighted_layer_idx = self.weighted_layer_names.index(layer_name)
        if weighted_layer_idx == len(self.weighted_layer_names) - 1:
            return None
        next_layer_name = self.weighted_layer_names[weighted_layer_idx + 1]
        return next_layer_name
```
Get the name of the next layer of the given layer from the stored `weighted_layer_names`. If the given layer is the last layer of the backbone, return `None`.

**Args:**
- **layer_name** (`str`): The name of the layer.

**Returns:**
- **next_layer_name** (`str` | `None`): The name of the next layer.

**Raises:**
- **ValueError**: If `layer_name` is not in the weighted layer order.
```python
    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for certain algorithms that need to use the hidden features for various purposes.
        """
```
The forward pass. **It must be implemented by subclasses.**

**Args:**
- **input** (`Tensor`): The input tensor from data.
- **stage** (`str`): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.

**Returns:**
- **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
- **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for certain algorithms that need to use the hidden features for various purposes.
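The two-part return contract (output feature plus per-layer activations keyed by weighted layer name) can be sketched in plain Python, with lists standing in for tensors and two hypothetical layers `fc1`/`fc2` (not CLArena code):

```python
# Sketch of the forward() contract: return the final feature AND a dict of
# post-activation hidden features, one entry per weighted layer.
def forward(input_vec: list[float], stage: str) -> tuple[list[float], dict[str, list[float]]]:
    relu = lambda xs: [max(x, 0.0) for x in xs]
    activations: dict[str, list[float]] = {}
    activations["fc1"] = relu([x * 2.0 for x in input_vec])                # first weighted layer
    activations["fc2"] = relu([x - 1.0 for x in activations["fc1"]])      # second weighted layer
    output_feature = activations["fc2"]  # the feature passed on to the output heads
    return output_feature, activations

out, acts = forward([1.0, -0.5], stage="train")
print(out)               # [1.0, 0.0]
print(sorted(acts))      # ['fc1', 'fc2']
```

Algorithms that regularize or mask hidden units (such as the HAT family below) read the `activations` dict rather than only the final feature.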
```python
class CLBackbone(Backbone):
    r"""The base class of continual learning backbone networks."""

    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

    def setup_task_id(self, task_id: int) -> None:
        r"""Set up task `task_id`. This must be done before the `forward()` method is called."""
        self.task_id = task_id
        self.processed_task_ids.append(task_id)

    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
        task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
        """
```
The base class of continual learning backbone networks.
```python
    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to the number of tasks in the CL dataset."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""
```
**Args:**
- **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
- **kwargs**: Reserved for multiple inheritance.

Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to the number of tasks in the CL dataset.
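The task-ID bookkeeping that `setup_task_id` performs before each task can be sketched standalone (a plain class, no torch; not CLArena code):

```python
# Sketch: CLBackbone-style task-ID control. setup_task_id must run before
# forward() so the backbone knows which task's data it is processing.
class TaskIdTracker:
    def __init__(self) -> None:
        self.task_id: int | None = None        # current task, set per task
        self.processed_task_ids: list[int] = []  # every task seen so far

    def setup_task_id(self, task_id: int) -> None:
        self.task_id = task_id
        self.processed_task_ids.append(task_id)

tracker = TaskIdTracker()
for t in (1, 2, 3):  # task IDs are 1-based, following the docstring above
    tracker.setup_task_id(t)
print(tracker.task_id)             # 3
print(tracker.processed_task_ids)  # [1, 2, 3]
```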
```python
    @override  # since `nn.Module` uses it
    def forward(
        self,
        input: Tensor,
        stage: str,
        task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
        """
```
The forward pass for data from task `task_id`. In some backbones, the forward pass might be different for different tasks. **It must be implemented by subclasses.**

**Args:**
- **input** (`Tensor`): The input tensor from data.
- **stage** (`str`): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **task_id** (`int` | `None`): The task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided; thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is not used. Best practice is not to provide this argument and leave it as the default value.

**Returns:**
- **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
- **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
```python
class HATMaskBackbone(CLBackbone):
    r"""The backbone network for HAT-based algorithms with learnable hard attention masks.

    HAT-based algorithms include:

    - [**HAT (Hard Attention to the Task, 2018)**](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
    - [**AdaHAT (Adaptive Hard Attention to the Task, 2024)**](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
    - **FG-AdaHAT** is an architecture-based continual learning approach that improves HAT by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
    """

    def __init__(self, output_dim: int | None, gate: str, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **gate** (`str`): The type of gate function turning the real value task embeddings into attention masks; one of:
            - `sigmoid`: the sigmoid function.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        self.gate: str = gate
        r"""The type of gate function."""
        self.gate_fn: Callable
        r"""The gate function mapping the real value task embeddings into attention masks."""

        if gate == "sigmoid":
            self.gate_fn = nn.Sigmoid()

        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
        r"""The task embedding for the current task `self.task_id`. Keys are layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).

        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.

        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)

        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
        """

        self.masks: dict[int, dict[str, Tensor]] = {}
        r"""The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, )."""

        HATMaskBackbone.sanity_check(self)

    def initialize_task_embedding(self, mode: str) -> None:
        r"""Initialize the task embedding for the current task `self.task_id`.

        **Args:**
        - **mode** (`str`): The initialization mode for task embeddings; one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit task embeddings from the last task.
        """
        for te in self.task_embedding_t.values():
            if mode == "N01":
                nn.init.normal_(te.weight, 0, 1)
            elif mode == "U-11":
                nn.init.uniform_(te.weight, -1, 1)
            elif mode == "U01":
                nn.init.uniform_(te.weight, 0, 1)
            elif mode == "U-10":
                nn.init.uniform_(te.weight, -1, 0)
            elif mode == "last":
                pass

    def sanity_check(self) -> None:
        r"""Sanity check."""
        if self.gate not in ["sigmoid"]:
            raise ValueError("The gate should be one of: 'sigmoid'.")

    def get_mask(
        self,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_task_id: int | None = None,
    ) -> dict[str, Tensor]:
        r"""Get the hard attention mask used in the `forward()` method for different stages.

        **Args:**
        - **stage** (`str`): The stage when applying the conversion; one of:
            1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by `s_max`, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
            3. 'test': testing stage. Apply the test mask directly from the stored masks using `test_task_id`.
            4. 'unlearning_test': unlearning testing stage. The mask is set to all 1s for unlearning testing.
        - **s_max** (`float` | `None`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **mask** (`dict[str, Tensor]`): The hard attention (with values 0 or 1) mask. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        """
        # sanity check
        if stage == "train" and (
            s_max is None or batch_idx is None or num_batches is None
        ):
            raise ValueError(
                "The `s_max`, `batch_idx` and `num_batches` should be provided at training stage, instead of the default value `None`."
            )
        if stage == "validation" and s_max is None:
            raise ValueError(
                "The `s_max` should be provided at validation stage, instead of the default value `None`."
            )
        if stage == "test" and test_task_id is None:
            raise ValueError(
                "The `test_task_id` should be provided at testing stage, instead of the default value `None`."
            )

        mask = {}
        if stage == "train":
            for layer_name in self.weighted_layer_names:
                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
                    num_batches - 1
                )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
                mask[layer_name] = self.gate_fn(
                    self.task_embedding_t[layer_name].weight * anneal_scalar
                ).squeeze()
        elif stage == "validation":
            for layer_name in self.weighted_layer_names:
                mask[layer_name] = self.gate_fn(
                    self.task_embedding_t[layer_name].weight * s_max
                ).squeeze()
        elif stage == "test":
            mask = self.masks[test_task_id]
            for layer_name, layer_mask in mask.items():
                layer = self.get_layer_by_name(layer_name)
                target_device = layer.weight.device
                if layer_mask.device != target_device:
                    mask[layer_name] = layer_mask.to(target_device)
        elif stage == "unlearning_test":
            for layer_name in self.weighted_layer_names:
                layer = self.get_layer_by_name(layer_name)
                mask[layer_name] = torch.ones(
                    layer.weight.size(0), device=layer.weight.device
                )

        return mask

    def te_to_binary_mask(self) -> dict[str, Tensor]:
        r"""Convert the current task embedding into a binary mask.

        This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.

        **Returns:**
        - **mask_t** (`dict[str, Tensor]`): The binary mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        """
        # get the mask for the current task
        mask_t = {
            layer_name: (self.task_embedding_t[layer_name].weight > 0)
            .float()
            .squeeze()
            .detach()
            for layer_name in self.weighted_layer_names
        }

        return mask_t

    def combine_masks(
        self, masks: list[dict[str, Tensor]], mode: str
    ) -> dict[str, Tensor]:
        r"""Combine multiple masks by taking their element-wise minimum (for intersection) / maximum (for union).

        **Args:**
        - **masks** (`list[dict[str, Tensor]]`): A list of masks. Each mask is a dict where keys are layer names and values are mask tensors.
        - **mode** (`str`): The combination mode; one of:
            - 'intersection': take the element-wise minimum of the masks (for intersection).
            - 'union': take the element-wise maximum of the masks (for union).

        **Returns:**
        - **combined_mask** (`dict[str, Tensor]`): The combined mask.
        """
        combined_mask = {}
        for layer_name in masks[0].keys():
            layer_mask_tensors = torch.stack(
                [mask[layer_name] for mask in masks], dim=0
            )
            if mode == "intersection":
                combined_mask[layer_name] = torch.min(layer_mask_tensors, dim=0).values
            elif mode == "union":
                combined_mask[layer_name] = torch.max(layer_mask_tensors, dim=0).values
            else:
                raise ValueError(
                    f"Unsupported mode: {mode}. Use 'intersection' or 'union'."
                )

        return combined_mask

    def store_mask(self) -> dict[str, Tensor]:
        r"""Store the mask for the current task `self.task_id`.

        **Returns:**
        - **mask_t** (`dict[str, Tensor]`): The stored binary mask for the current task.
        """
        mask_t = self.te_to_binary_mask()

        for subhatmodule in self.modules():
            if isinstance(subhatmodule, HATMaskBackbone):  # for all sub HAT modules
                subhatmodule.masks[self.task_id] = mask_t

        return mask_t

    def get_layer_measure_parameter_wise(
        self,
        neuron_wise_measure: dict[str, Tensor],
        layer_name: str,
        aggregation_mode: str,
    ) -> Tensor:
        r"""Get the parameter-wise measure on the parameters right before the given layer.

        It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.

        Note that if the given layer is the first layer with no preceding layer, we will get the parameter-wise measure directly by broadcasting from the neuron-wise measure of the given layer.

        This method is used to calculate parameter-wise measures in various HAT-based algorithms:

        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks `cumulative_mask_for_previous_tasks`, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks `summative_mask_for_previous_tasks`, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **FG-AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks `summative_importance_for_previous_tasks`, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.

        **Args:**
        - **neuron_wise_measure** (`dict[str, Tensor]`): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
        - **layer_name** (`str`): The name of the given layer.
        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
            - 'min': takes the minimum of the two connected unit measures.
            - 'max': takes the maximum of the two connected unit measures.
            - 'mean': takes the mean of the two connected unit measures.

        **Returns:**
        - **weight_measure** (`Tensor`): The weight measure matrix, the same size as the corresponding weights.
        - **bias_measure** (`Tensor`): The bias measure vector, the same size as the corresponding bias.
        """
        # initialize the aggregation function
        if aggregation_mode == "min":
            aggregation_func = torch.min
        elif aggregation_mode == "max":
            aggregation_func = torch.max
        elif aggregation_mode == "mean":
            aggregation_func = torch.mean
        else:
            raise ValueError(
                f"The aggregation method {aggregation_mode} is not supported."
            )

        # get the preceding layer
        preceding_layer_name = self.preceding_layer_name(layer_name)

        # get weight size for expanding the measures
        layer = self.get_layer_by_name(layer_name)
        weight_size = layer.weight.size()

        # construct the weight-wise measure
        layer_measure = neuron_wise_measure[layer_name]
        layer_measure_broadcast_size = (-1, 1) + tuple(
            1 for _ in range(len(weight_size) - 2)
        )  # since the size of mask tensor is (number of units, ), we extend it to (number of units, 1) and expand it to the weight size. The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers

        layer_measure_broadcasted = layer_measure.view(
            *layer_measure_broadcast_size
        ).expand(
            weight_size,
        )  # expand the given layer mask to the weight size and broadcast

        if (
            preceding_layer_name
        ):  # if the layer is not the first layer, where the preceding layer exists

            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
                1 for _ in range(len(weight_size) - 2)
            )  # since the size of mask tensor is (number of units, ), we extend it to (1, number of units) and expand it to the weight size.
```
The weight size has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers 448 preceding_layer_measure = neuron_wise_measure[preceding_layer_name] 449 preceding_layer_measure_broadcasted = preceding_layer_measure.view( 450 *preceding_layer_measure_broadcast_size 451 ).expand( 452 weight_size 453 ) # expand the preceding layer mask to the weight size and broadcast 454 weight_measure = aggregation_func( 455 layer_measure_broadcasted, preceding_layer_measure_broadcasted 456 ) # get the minimum of the two mask vectors, from expanded 457 else: # if the layer is the first layer 458 weight_measure = layer_measure_broadcasted 459 460 # construct the bias-wise measure 461 bias_measure = layer_measure 462 463 return weight_measure, bias_measure 464 465 @override 466 def forward( 467 self, 468 input: Tensor, 469 stage: str, 470 s_max: float | None = None, 471 batch_idx: int | None = None, 472 num_batches: int | None = None, 473 test_task_id: int | None = None, 474 ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]: 475 r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the units in each layer. 476 477 **Args:** 478 - **input** (`Tensor`): The input tensor from data. 479 - **stage** (`str`): The stage of the forward pass; one of: 480 1. 'train': training stage. 481 2. 'validation': validation stage. 482 3. 'test': testing stage. 483 - **s_max** (`float`): The maximum scaling factor in the gate function. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 484 - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`. 485 - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`. 486 - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`. 
487 488 **Returns:** 489 - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation. 490 - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ). 491 - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class. 492 """
The backbone network for HAT-based algorithms with learnable hard attention masks.
HAT-based algorithms include:
- HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
- AdaHAT (Adaptive Hard Attention to the Task, 2024) is an architecture-based continual learning approach that improves HAT by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
- FG-AdaHAT is an architecture-based continual learning approach that improves HAT by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
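The mechanism shared by all of these algorithms is the gate itself: each layer keeps a real-valued task embedding, and passing it through a scaled sigmoid yields the attention mask. The sketch below is framework-free (plain `math` instead of `torch`, with hypothetical embedding values) and shows how a large scale `s_max` saturates the gate toward a binary mask:

```python
import math


def sigmoid(x: float) -> float:
    # numerically safe sigmoid: avoid overflow in math.exp for large |x|
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# a hypothetical task embedding for one layer (one value per unit)
embedding = [2.0, -1.5, 0.3, -0.1]

# small scale: a soft, trainable attention mask
soft_mask = [sigmoid(1.0 * e) for e in embedding]

# large scale s_max: the gate saturates toward a hard binary mask
s_max = 400.0
hard_mask = [sigmoid(s_max * e) for e in embedding]

print(soft_mask)  # graded values strictly between 0 and 1
print(hard_mask)  # effectively 0.0 or 1.0 per unit
```

Because the soft mask is differentiable, the embedding can be trained with ordinary backpropagation, while the saturated mask at test time behaves as a hard parameter selection.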
```python
    def __init__(self, output_dim: int | None, gate: str, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension. In this case, it can be `None`.
        - **gate** (`str`): The type of gate function turning the real-valued task embeddings into attention masks; one of:
            - `sigmoid`: the sigmoid function.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        self.gate: str = gate
        r"""The type of gate function."""
        self.gate_fn: Callable
        r"""The gate function mapping the real-valued task embeddings into attention masks."""

        if gate == "sigmoid":
            self.gate_fn = nn.Sigmoid()

        self.task_embedding_t: nn.ModuleDict = nn.ModuleDict()
        r"""The task embedding for the current task `self.task_id`. Keys are layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has size (1, number of units).

        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.

        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, which is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)

        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become the source of catastrophic forgetting.
        """

        self.masks: dict[int, dict[str, Tensor]] = {}
        r"""The binary attention mask of each previous task gated from the task embedding. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, )."""

        HATMaskBackbone.sanity_check(self)
```
```python
    def initialize_task_embedding(self, mode: str) -> None:
        r"""Initialize the task embedding for the current task `self.task_id`.

        **Args:**
        - **mode** (`str`): The initialization mode for task embeddings; one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit task embeddings from the last task.
        """
        for te in self.task_embedding_t.values():
            if mode == "N01":
                nn.init.normal_(te.weight, 0, 1)
            elif mode == "U-11":
                nn.init.uniform_(te.weight, -1, 1)
            elif mode == "U01":
                nn.init.uniform_(te.weight, 0, 1)
            elif mode == "U-10":
                nn.init.uniform_(te.weight, -1, 0)
            elif mode == "last":
                pass  # keep the embeddings inherited from the last task
```
```python
    def sanity_check(self) -> None:
        r"""Sanity check."""
        if self.gate not in ["sigmoid"]:
            raise ValueError("The gate should be one of: 'sigmoid'.")
```
```python
    def get_mask(
        self,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_task_id: int | None = None,
    ) -> dict[str, Tensor]:
        r"""Get the hard attention mask used in the `forward()` method for different stages.

        **Args:**
        - **stage** (`str`): The stage when applying the conversion; one of:
            1. 'train': training stage. Get the mask from the current task embedding through the gate function, scaled by an annealed scalar. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'validation': validation stage. Get the mask from the current task embedding through the gate function, scaled by `s_max`, where large scaling makes masks nearly binary. (Note that in this stage, the binary mask hasn't been stored yet, as training is not over.)
            3. 'test': testing stage. Apply the test mask directly from the stored masks using `test_task_id`.
            4. 'unlearning_test': unlearning testing stage. The mask is set to all 1s for unlearning testing.
        - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **mask** (`dict[str, Tensor]`): The attention mask (values in $[0, 1]$; effectively binary at large scaling). Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        """
        # sanity check
        if stage == "train" and (
            s_max is None or batch_idx is None or num_batches is None
        ):
            raise ValueError(
                "The `s_max`, `batch_idx` and `num_batches` should be provided at training stage, instead of the default value `None`."
            )
        if stage == "validation" and (s_max is None):
            raise ValueError(
                "The `s_max` should be provided at validation stage, instead of the default value `None`."
            )
        if stage == "test" and (test_task_id is None):
            raise ValueError(
                "The `test_task_id` should be provided at testing stage, instead of the default value `None`."
            )

        mask = {}
        if stage == "train":
            for layer_name in self.weighted_layer_names:
                anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (
                    num_batches - 1
                )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
                mask[layer_name] = self.gate_fn(
                    self.task_embedding_t[layer_name].weight * anneal_scalar
                ).squeeze()
        elif stage == "validation":
            for layer_name in self.weighted_layer_names:
                mask[layer_name] = self.gate_fn(
                    self.task_embedding_t[layer_name].weight * s_max
                ).squeeze()
        elif stage == "test":
            mask = self.masks[test_task_id]
            for layer_name, layer_mask in mask.items():
                layer = self.get_layer_by_name(layer_name)
                target_device = layer.weight.device
                if layer_mask.device != target_device:
                    mask[layer_name] = layer_mask.to(target_device)
        elif stage == "unlearning_test":
            for layer_name in self.weighted_layer_names:
                layer = self.get_layer_by_name(layer_name)
                mask[layer_name] = torch.ones(
                    layer.weight.size(0), device=layer.weight.device
                )

        return mask
```
```python
    def te_to_binary_mask(self) -> dict[str, Tensor]:
        r"""Convert the current task embedding into a binary mask.

        This method is used before the testing stage to convert the task embedding into a binary mask for each layer. The binary mask is used to select parameters for the current task.

        **Returns:**
        - **mask_t** (`dict[str, Tensor]`): The binary mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        """
        # get the mask for the current task: a unit is kept iff its embedding entry is positive
        mask_t = {
            layer_name: (self.task_embedding_t[layer_name].weight > 0)
            .float()
            .squeeze()
            .detach()
            for layer_name in self.weighted_layer_names
        }

        return mask_t
```
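As a tiny illustration of the conversion (hypothetical embedding values; plain Python lists stand in for the per-layer tensors): a unit is kept exactly when its embedding entry is positive, matching the `weight > 0` test above.

```python
# Threshold a real-valued task embedding at 0 to obtain the stored binary
# mask: positive entries keep the unit (1.0), non-positive entries drop it (0.0).
embedding = [0.7, -2.1, 0.0, 3.2]
binary_mask = [1.0 if value > 0 else 0.0 for value in embedding]
print(binary_mask)  # [1.0, 0.0, 0.0, 1.0]
```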
```python
    def combine_masks(
        self, masks: list[dict[str, Tensor]], mode: str
    ) -> dict[str, Tensor]:
        r"""Combine multiple masks by taking their element-wise minimum (for intersection) or maximum (for union).

        **Args:**
        - **masks** (`list[dict[str, Tensor]]`): A list of masks. Each mask is a dict where keys are layer names and values are mask tensors.
        - **mode** (`str`): The combination mode; one of:
            - 'intersection': take the element-wise minimum of the masks.
            - 'union': take the element-wise maximum of the masks.

        **Returns:**
        - **combined_mask** (`dict[str, Tensor]`): The combined mask.
        """
        combined_mask = {}
        for layer_name in masks[0].keys():
            layer_mask_tensors = torch.stack(
                [mask[layer_name] for mask in masks], dim=0
            )
            if mode == "intersection":
                combined_mask[layer_name] = torch.min(layer_mask_tensors, dim=0).values
            elif mode == "union":
                combined_mask[layer_name] = torch.max(layer_mask_tensors, dim=0).values
            else:
                raise ValueError(
                    f"Unsupported mode: {mode}. Use 'intersection' or 'union'."
                )

        return combined_mask
```
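A framework-free sketch of the two combination modes (hypothetical layer name and mask values; plain lists stand in for the mask tensors): 'intersection' keeps a unit only if every mask keeps it (element-wise minimum), while 'union' keeps a unit if any mask keeps it (element-wise maximum).

```python
mask_a = {"fc1": [1.0, 0.0, 1.0, 1.0]}
mask_b = {"fc1": [1.0, 1.0, 0.0, 1.0]}

# element-wise minimum: a unit survives only if kept by both masks
intersection = {
    name: [min(pair) for pair in zip(mask_a[name], mask_b[name])]
    for name in mask_a
}
# element-wise maximum: a unit survives if kept by either mask
union = {
    name: [max(pair) for pair in zip(mask_a[name], mask_b[name])]
    for name in mask_a
}

print(intersection)  # {'fc1': [1.0, 0.0, 0.0, 1.0]}
print(union)         # {'fc1': [1.0, 1.0, 1.0, 1.0]}
```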
```python
    def store_mask(self) -> dict[str, Tensor]:
        r"""Store the binary mask for the current task `self.task_id` in all sub HAT modules and return it."""
        mask_t = self.te_to_binary_mask()

        for subhatmodule in self.modules():
            if isinstance(subhatmodule, HATMaskBackbone):  # for all sub HAT modules
                subhatmodule.masks[self.task_id] = mask_t

        return mask_t
```
```python
    def get_layer_measure_parameter_wise(
        self,
        neuron_wise_measure: dict[str, Tensor],
        layer_name: str,
        aggregation_mode: str,
    ) -> tuple[Tensor, Tensor]:
        r"""Get the parameter-wise measure on the parameters right before the given layer.

        It is calculated from the given neuron-wise measure. It aggregates two feature-sized vectors (corresponding to the given layer and the preceding layer) into a weight-wise matrix (corresponding to the weights in between) and a bias-wise vector (corresponding to the bias of the given layer), using the given aggregation method. For example, given two feature-sized measures $m_{l,i}$ and $m_{l-1,j}$ and 'min' aggregation, the parameter-wise measure is $\min \left(m_{l,i}, m_{l-1,j}\right)$, a matrix with respect to $i, j$.

        Note that if the given layer is the first layer with no preceding layer, the parameter-wise measure is obtained directly by broadcasting the neuron-wise measure of the given layer.

        This method is used to calculate parameter-wise measures in various HAT-based algorithms:

        - **HAT**: the parameter-wise measure is the binary mask for previous tasks from the neuron-wise cumulative mask of previous tasks `cumulative_mask_for_previous_tasks`, which is $\text{Agg} \left(a_{l,i}^{<t}, a_{l-1,j}^{<t}\right)$ in Eq. (2) in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise summative mask of previous tasks `summative_mask_for_previous_tasks`, which is $\text{Agg} \left(m_{l,i}^{<t,\text{sum}}, m_{l-1,j}^{<t,\text{sum}}\right)$ in Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **FG-AdaHAT**: the parameter-wise measure is the parameter importance for previous tasks from the neuron-wise importance of previous tasks `summative_importance_for_previous_tasks`, which is $\text{Agg} \left(I_{l,i}^{<t}, I_{l-1,j}^{<t}\right)$ in Eq. (2) in the FG-AdaHAT paper.

        **Args:**
        - **neuron_wise_measure** (`dict[str, Tensor]`): The neuron-wise measure. Keys are layer names and values are the neuron-wise measure tensor. The tensor has size (number of units, ).
        - **layer_name** (`str`): The name of the given layer.
        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
            - 'min': takes the minimum of the two connected unit measures.
            - 'max': takes the maximum of the two connected unit measures.
            - 'mean': takes the mean of the two connected unit measures.

        **Returns:**
        - **weight_measure** (`Tensor`): The weight measure matrix, the same size as the corresponding weights.
        - **bias_measure** (`Tensor`): The bias measure vector, the same size as the corresponding bias.
        """
        # initialize the aggregation function
        if aggregation_mode == "min":
            aggregation_func = torch.min
        elif aggregation_mode == "max":
            aggregation_func = torch.max
        elif aggregation_mode == "mean":
            # `torch.mean` has no element-wise two-tensor form, so average explicitly
            aggregation_func = lambda measure_1, measure_2: (measure_1 + measure_2) / 2
        else:
            raise ValueError(
                f"The aggregation method {aggregation_mode} is not supported."
            )

        # get the preceding layer
        preceding_layer_name = self.preceding_layer_name(layer_name)

        # get weight size for expanding the measures
        layer = self.get_layer_by_name(layer_name)
        weight_size = layer.weight.size()

        # construct the weight-wise measure
        layer_measure = neuron_wise_measure[layer_name]
        layer_measure_broadcast_size = (-1, 1) + tuple(
            1 for _ in range(len(weight_size) - 2)
        )  # the mask tensor has size (number of units, ); view it as (number of units, 1, ...) so it can be expanded to the weight size. The weight has 2 dimensions in fully connected layers and 4 dimensions in convolutional layers

        layer_measure_broadcasted = layer_measure.view(
            *layer_measure_broadcast_size
        ).expand(
            weight_size
        )  # expand the given layer measure to the weight size

        if preceding_layer_name:  # not the first layer; a preceding layer exists
            preceding_layer_measure_broadcast_size = (1, -1) + tuple(
                1 for _ in range(len(weight_size) - 2)
            )  # view the preceding measure as (1, number of units, ...) so it can be expanded to the weight size
            preceding_layer_measure = neuron_wise_measure[preceding_layer_name]
            preceding_layer_measure_broadcasted = preceding_layer_measure.view(
                *preceding_layer_measure_broadcast_size
            ).expand(
                weight_size
            )  # expand the preceding layer measure to the weight size
            weight_measure = aggregation_func(
                layer_measure_broadcasted, preceding_layer_measure_broadcasted
            )  # aggregate the two broadcasted measures element-wise
        else:  # the first layer: broadcast the given layer measure directly
            weight_measure = layer_measure_broadcasted

        # construct the bias-wise measure
        bias_measure = layer_measure

        return weight_measure, bias_measure
```
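To see what the broadcasting achieves, here is a minimal sketch of the 'min' aggregation for a fully connected layer (hypothetical 3-unit measures; plain nested lists stand in for the weight matrix): the two neuron-wise vectors become a weight-wise matrix whose entry $(i, j)$ is $\min(m_{l,i}, m_{l-1,j})$, matching the (out_features, in_features) shape of the layer's weight.

```python
m_layer = [1.0, 0.0, 1.0]      # measure for the 3 output units of the layer
m_preceding = [1.0, 1.0, 0.0]  # measure for the 3 input units (preceding layer)

# weight_measure[i][j] = min(m_layer[i], m_preceding[j]): a weight is "kept"
# only if both units it connects are kept
weight_measure = [[min(mi, mj) for mj in m_preceding] for mi in m_layer]
bias_measure = m_layer  # the bias follows the current layer's units

print(weight_measure)
# [[1.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.0, 1.0, 0.0]]
```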
```python
    @override
    def forward(
        self,
        input: Tensor,
        stage: str,
        s_max: float | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        test_task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the units in each layer.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **s_max** (`float`): The maximum scaling factor in the gate function. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
        """
```
class AmnesiacHATBackbone(HATMaskBackbone):
    r"""The backbone network for AmnesiacHAT, built on top of HAT. AmnesiacHAT introduces parallel backup backbones to mitigate the side effects of unlearning."""

    original_backbone_class: type[Backbone]
    r"""The original backbone class used to instantiate backup backbones. Must be defined in subclasses."""

    def __init__(
        self,
        output_dim: int | None,
        gate: str,
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension; in that case, it can be `None`.
        - **gate** (`str`): The type of gate function turning the real-valued task embeddings into attention masks; one of:
            - `sigmoid`: the sigmoid function.
        - **disable_unlearning** (`bool`): Whether to disable unlearning. This is used in reference experiments that follow the continual learning pipeline. Default is `False`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, gate=gate, **kwargs)

        self.disable_unlearning: bool = disable_unlearning
        r"""Whether to disable unlearning. This is used in reference experiments that follow the continual learning pipeline."""

        if not disable_unlearning:
            self.backup_backbones: nn.ModuleDict
            r"""The backup backbone networks. Keys are task IDs (in string format, because `ModuleDict` keys have to be strings) of the tasks backed up in case they are unlearned, and values are the corresponding backup backbone networks. They all have the same architecture as the main backbone network.

            Please note that we use `ModuleDict` rather than `dict` to ensure `LightningModule` can track these model parameters for training. DO NOT change this to `dict`."""

            self.backup_task_ids: list[int]
            r"""The task IDs that need backup backbones at the current task `self.task_id`."""

            self.backup_state_dicts: dict[tuple[int, int], dict[str, Tensor]] = {}
            r"""The backup state dict for each task. Keys are tuples (backup task ID, the task ID that the backup is for) and values are the corresponding state dicts."""

    def instantiate_backup_backbones(
        self,
        backup_task_ids: list[int],
    ) -> None:
        r"""Instantiate the backup backbone networks for the current task. This is called when a new task is created.

        **Args:**
        - **backup_task_ids** (`list[int]`): The list of task IDs to back up at the current task `self.task_id`.
        """

        self.backup_task_ids = backup_task_ids

        self.backup_backbones = nn.ModuleDict(
            {
                f"{task_id_to_backup}": self.original_backbone_class(
                    **self.backup_backbone_kwargs,
                )
                for task_id_to_backup in backup_task_ids
            }
        )

        pylogger.debug(
            "Backup backbones (backing up task IDs %s) for current task ID %d have been instantiated.",
            backup_task_ids,
            self.task_id,
        )
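The `ModuleDict`-over-`dict` requirement above can be demonstrated in isolation. A minimal sketch, using a hypothetical `TinyBackbone` in place of `original_backbone_class`: because `nn.ModuleDict` registers its values as submodules, the backup backbones' parameters show up in `parameters()` and are therefore tracked by `LightningModule`; a plain `dict` would hide them.

```python
import torch.nn as nn

# Hypothetical minimal backbone standing in for `original_backbone_class`.
class TinyBackbone(nn.Module):
    def __init__(self, output_dim: int) -> None:
        super().__init__()
        self.fc = nn.Linear(4, output_dim)

backup_task_ids = [1, 3]

# Keys must be strings: nn.ModuleDict (unlike dict) accepts only string keys,
# and it registers its values as submodules so their parameters are tracked.
backup_backbones = nn.ModuleDict(
    {f"{tid}": TinyBackbone(output_dim=2) for tid in backup_task_ids}
)

# Each backup contributes its full parameter set (4*2 weights + 2 biases here).
total_params = sum(p.numel() for p in backup_backbones.parameters())
```

With a plain `dict`, `total_params` computed over the parent module's `parameters()` would be 0 for the backups, and the optimizer would silently skip them.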
class WSNMaskBackbone(CLBackbone):
    r"""The backbone network for the WSN algorithm with learnable parameter masks.

    [WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) is an architecture-based continual learning algorithm. It trains learnable parameter-wise scores and selects the top-scored $c\%$ of the network parameters to be used for each task.
    """

    def __init__(self, output_dim: int | None, **kwargs) -> None:
        r"""
        **Args:**
        - **output_dim** (`int` | `None`): The output dimension that connects to CL output heads. The `input_dim` of output heads is expected to be the same as this `output_dim`. In some cases, this class is used as a block in the backbone network that doesn't have an output dimension; in that case, it can be `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(output_dim=output_dim, **kwargs)

        self.gate_fn: torch.autograd.Function = PercentileLayerParameterMaskingByScore
        r"""The gate function mapping the real-valued parameter scores into binary parameter masks. It is a custom autograd function that applies percentile parameter masking by score."""

        self.weight_score_t: nn.ModuleDict = nn.ModuleDict()
        r"""The weight scores for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (output features, input features) as the weight.

        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.

        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, because `nn.Embedding` is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)

        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become a source of catastrophic forgetting.
        """

        self.bias_score_t: nn.ModuleDict = nn.ModuleDict()
        r"""The bias scores for the current task `self.task_id`. Keys are the layer names and values are the task embedding `nn.Embedding` for the layer. Each task embedding has the same size (1, output features) as the bias. If the layer doesn't have a bias, it is `None`.

        We use `ModuleDict` rather than `dict` to ensure `LightningModule` properly registers these model parameters for purposes such as automatic device transfer and model summaries.

        We use `nn.Embedding` rather than `nn.Parameter` to store the task embedding for each layer, because `nn.Embedding` is a type of `nn.Module` and can be accepted by `nn.ModuleDict`. (`nn.Parameter` cannot be accepted by `nn.ModuleDict`.)

        **This must be defined to cover each weighted layer (as listed in `weighted_layer_names`) in the backbone network.** Otherwise, the uncovered parts will keep being updated for all tasks and become a source of catastrophic forgetting.
        """

        WSNMaskBackbone.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

    def initialize_parameter_score(self, mode: str) -> None:
        r"""Initialize the parameter scores for the current task.

        **Args:**
        - **mode** (`str`): The initialization mode for parameter scores; one of:
            1. 'default': the default initialization mode in the original WSN code.
            2. 'N01': standard normal distribution $N(0, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
        """

        for layer_name, weight_score in self.weight_score_t.items():
            if mode == "default":
                # Kaiming uniform initialization for the weight scores
                nn.init.kaiming_uniform_(weight_score.weight, a=math.sqrt(5))

                for layer_name, bias_score in self.bias_score_t.items():
                    if bias_score is not None:
                        # For biases, follow the standard bias initialization using fan_in
                        weight_score = self.weight_score_t[layer_name]
                        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                            weight_score.weight
                        )
                        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                        nn.init.uniform_(bias_score.weight, -bound, bound)
            elif mode == "N01":
                nn.init.normal_(weight_score.weight, 0, 1)
                for layer_name, bias_score in self.bias_score_t.items():
                    if bias_score is not None:
                        nn.init.normal_(bias_score.weight, 0, 1)
            elif mode == "U01":
                nn.init.uniform_(weight_score.weight, 0, 1)
                for layer_name, bias_score in self.bias_score_t.items():
                    if bias_score is not None:
                        nn.init.uniform_(bias_score.weight, 0, 1)

    def get_mask(
        self,
        stage: str,
        mask_percentage: float,
        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
    ) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
        r"""Get the binary parameter masks used in the `forward()` method for different stages.

        **Args:**
        - **stage** (`str`): The stage when applying the conversion; one of:
            1. 'train': training stage. Get the masks from the parameter scores of the current task through the gate function, which masks the top $c\%$ highest-scored parameters. See Sec. 3.1 "Winning Subnetworks" in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
            2. 'validation': validation stage. Same as 'train'. (Note that in this stage, the binary masks haven't been stored yet, as training is not over.)
            3. 'test': testing stage. Apply the test masks directly from the argument `test_mask`.
        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **weight_mask** (`dict[str, Tensor]`): The binary mask on weights. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
        - **bias_mask** (`dict[str, Tensor]`): The binary mask on biases. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
        """
        weight_mask = {}
        bias_mask = {}
        if stage == "train" or stage == "validation":
            for layer_name in self.weighted_layer_names:
                weight_mask[layer_name] = self.gate_fn.apply(
                    self.weight_score_t[layer_name].weight, mask_percentage
                )
                if self.bias_score_t[layer_name] is not None:
                    bias_mask[layer_name] = self.gate_fn.apply(
                        self.bias_score_t[layer_name].weight.squeeze(
                            0
                        ),  # from (1, output_dim) to (output_dim, )
                        mask_percentage,
                    )
                else:
                    bias_mask[layer_name] = None
        elif stage == "test":
            weight_mask, bias_mask = test_mask

        return weight_mask, bias_mask

    @override
    def forward(
        self,
        input: Tensor,
        stage: str,
        mask_percentage: float,
        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `self.task_id`. Task-specific parameter masks for `self.task_id` are applied in each layer.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need the hidden features for various purposes.
        """
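The internals of `PercentileLayerParameterMaskingByScore` are not shown in this listing. A hypothetical minimal stand-in (named `TopScoreMask` here, not part of CLArena) illustrating the idea might look like this: the forward pass keeps the top fraction of parameters by score as a binary mask, while the backward pass uses a straight-through estimator so the real-valued scores remain trainable through the non-differentiable selection.

```python
import torch

class TopScoreMask(torch.autograd.Function):
    """Hypothetical stand-in for PercentileLayerParameterMaskingByScore."""

    @staticmethod
    def forward(ctx, score: torch.Tensor, mask_percentage: float) -> torch.Tensor:
        # Keep the top `mask_percentage` fraction of entries by score.
        k = int(mask_percentage * score.numel())
        mask = torch.zeros_like(score)
        if k > 0:
            top_indices = torch.topk(score.flatten(), k).indices
            mask.view(-1)[top_indices] = 1.0
        return mask

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Straight-through estimator: gradients pass to the scores unchanged;
        # no gradient for the percentage argument.
        return grad_output, None

score = torch.tensor([[0.9, 0.1], [0.4, 0.8]], requires_grad=True)
mask = TopScoreMask.apply(score, 0.5)  # keeps the 2 highest-scored entries
```

This is a sketch of the mechanism only; the actual gate function in CLArena may differ in tie-breaking, per-layer percentile handling, and gradient treatment.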
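The 'default' branch of `initialize_parameter_score` mirrors PyTorch's standard `nn.Linear` initialization: Kaiming-uniform weights with $a=\sqrt{5}$, then biases drawn from $U(-1/\sqrt{\text{fan\_in}}, 1/\sqrt{\text{fan\_in}})$. A small sketch of the bound computation, assuming a single layer with an (8, 16) weight score and a (1, 8) bias score (these shapes are illustrative, not from the source):

```python
import math
import torch.nn as nn

# Assumed shapes: weight score (8, 16) -> fan_in = 16; bias score (1, 8).
weight_score = nn.Embedding(8, 16)
bias_score = nn.Embedding(1, 8)

nn.init.kaiming_uniform_(weight_score.weight, a=math.sqrt(5))

# Same (private) helper the source uses to derive the bias bound.
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight_score.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(bias_score.weight, -bound, bound)
```

Here `fan_in` is the second dimension of the 2-D weight score (16), so the bias bound is $1/\sqrt{16} = 0.25$.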