clarena.backbones.mlp
The submodule in `backbones` for the MLP backbone network. It provides multiple versions of the MLP: the basic `MLP`, the continual learning `CLMLP`, the HAT-masked `HATMaskMLP`, and the WSN-masked `WSNMaskMLP`.
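Before the full module source below, a minimal usage sketch of the basic `MLP` (a hedged example, not taken from the library's own docs: it assumes `clarena` is installed and that the base `Backbone` class needs nothing beyond `output_dim`; the dimensions and data are illustrative):

import torch

from clarena.backbones.mlp import MLP

# a 784 -> 256 -> 128 -> 64 network; the 64-dim output feature is what the heads consume
backbone = MLP(input_dim=784, hidden_dims=[256, 128], output_dim=64, dropout=0.1)

x = torch.randn(32, 1, 28, 28)  # e.g. a batch of MNIST-like images; flattened internally
output_feature, activations = backbone(x)

print(output_feature.shape)        # torch.Size([32, 64])
print(sorted(activations.keys()))  # ['fc/0', 'fc/1', 'fc/2']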
1r""" 2The submodule in `backbones` for the MLP backbone network. It includes multiple versions of MLP, including the basic MLP, the continual learning MLP, the [HAT](http://proceedings.mlr.press/v80/serra18a) masked MLP, and the [WSN](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) masked MLP. 3""" 4 5__all__ = ["MLP", "CLMLP", "HATMaskMLP", "WSNMaskMLP"] 6 7import logging 8from copy import deepcopy 9 10from torch import Tensor, nn 11 12from clarena.backbones import Backbone, CLBackbone, HATMaskBackbone, WSNMaskBackbone 13 14# always get logger for built-in logging in each module 15pylogger = logging.getLogger(__name__) 16 17 18class MLP(Backbone): 19 """Multi-layer perceptron (MLP), a.k.a. fully connected network. 20 21 MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the output heads. 22 """ 23 24 def __init__( 25 self, 26 input_dim: int, 27 hidden_dims: list[int], 28 output_dim: int, 29 activation_layer: nn.Module | None = nn.ReLU, 30 batch_normalization: bool = False, 31 bias: bool = True, 32 dropout: float | None = None, 33 **kwargs, 34 ) -> None: 35 r"""Construct and initialize the MLP backbone network. 36 37 **Args:** 38 - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP. Note that it is not required in convolutional networks. 39 - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension. 40 - **output_dim** (`int`): The output dimension that connects to output heads. 41 - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`. 42 - **batch_normalization** (`bool`): Whether to use batch normalization after the fully connected layers. Default `False`. 43 - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`. 44 - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`. 45 - **kwargs**: Reserved for multiple inheritance. 
46 """ 47 super().__init__(output_dim=output_dim, **kwargs) 48 49 self.input_dim: int = input_dim 50 r"""The input dimension of the MLP backbone network.""" 51 self.hidden_dims: list[int] = hidden_dims 52 r"""The hidden dimensions of the MLP backbone network.""" 53 self.output_dim: int = output_dim 54 r"""The output dimension of the MLP backbone network.""" 55 56 self.num_fc_layers: int = len(hidden_dims) + 1 57 r"""The number of fully-connected layers in the MLP backbone network, which helps form the loops in constructing layers and forward pass.""" 58 self.batch_normalization: bool = batch_normalization 59 r"""Whether to use batch normalization after the fully-connected layers.""" 60 self.activation: bool = activation_layer is not None 61 r"""Whether to use activation function after the fully-connected layers.""" 62 self.dropout: bool = dropout is not None 63 r"""Whether to use dropout after the fully-connected layers.""" 64 65 self.fc: nn.ModuleList = nn.ModuleList() 66 r"""The list of fully connected (`nn.Linear`) layers.""" 67 if self.batch_normalization: 68 self.fc_bn: nn.ModuleList = nn.ModuleList() 69 r"""The list of batch normalization (`nn.BatchNorm1d`) layers after the fully connected layers.""" 70 if self.activation: 71 self.fc_activation: nn.ModuleList = nn.ModuleList() 72 r"""The list of activation layers after the fully connected layers.""" 73 if self.dropout: 74 self.fc_dropout: nn.ModuleList = nn.ModuleList() 75 r"""The list of dropout layers after the fully connected layers.""" 76 77 # construct the weighted fully connected layers and attached layers (batch norm, activation, dropout, etc.) in a loop 78 for layer_idx in range(self.num_fc_layers): 79 80 # the input and output dim of the current weighted layer 81 layer_input_dim = ( 82 self.input_dim if layer_idx == 0 else self.hidden_dims[layer_idx - 1] 83 ) 84 layer_output_dim = ( 85 self.hidden_dims[layer_idx] 86 if layer_idx != len(self.hidden_dims) 87 else self.output_dim 88 ) 89 90 # construct the fully connected layer 91 self.fc.append( 92 nn.Linear( 93 in_features=layer_input_dim, 94 out_features=layer_output_dim, 95 bias=bias, 96 ) 97 ) 98 99 # update the weighted layer names 100 full_layer_name = f"fc/{layer_idx}" 101 self.weighted_layer_names.append(full_layer_name) 102 103 # construct the batch normalization layer 104 if self.batch_normalization: 105 self.fc_bn.append(nn.BatchNorm1d(num_features=(layer_output_dim))) 106 107 # construct the activation layer 108 if self.activation: 109 self.fc_activation.append(activation_layer()) 110 111 # construct the dropout layer 112 if self.dropout: 113 self.fc_dropout.append(nn.Dropout(dropout)) 114 115 def forward( 116 self, input: Tensor, stage: str = None 117 ) -> tuple[Tensor, dict[str, Tensor]]: 118 r"""The forward pass for data. It is the same for all tasks. 119 120 **Args:** 121 - **input** (`Tensor`): The input tensor from data. 122 123 **Returns:** 124 - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation. 125 - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for certain algorithms that need to use hidden features for various purposes. 
126 """ 127 batch_size = input.size(0) 128 activations = {} 129 130 x = input.view(batch_size, -1) # flatten before going through MLP 131 132 for layer_idx, layer_name in enumerate(self.weighted_layer_names): 133 x = self.fc[layer_idx](x) # fully-connected layer first 134 if self.batch_normalization: 135 x = self.fc_bn[layer_idx]( 136 x 137 ) # batch normalization can be before or after activation. We put it before activation here 138 if self.activation: 139 x = self.fc_activation[layer_idx](x) # activation function third 140 activations[layer_name] = x # store the hidden feature 141 if self.dropout: 142 x = self.fc_dropout[layer_idx](x) # dropout last 143 144 output_feature = x 145 146 return output_feature, activations 147 148 149class CLMLP(CLBackbone, MLP): 150 """Multi-layer perceptron (MLP), a.k.a. fully connected network. Used as a continual learning backbone. 151 152 MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads. 153 """ 154 155 def __init__( 156 self, 157 input_dim: int, 158 hidden_dims: list[int], 159 output_dim: int, 160 activation_layer: nn.Module | None = nn.ReLU, 161 batch_normalization: bool = False, 162 bias: bool = True, 163 dropout: float | None = None, 164 **kwargs, 165 ) -> None: 166 r"""Construct and initialize the CLMLP backbone network. 167 168 **Args:** 169 - **input_dim** (`int`): the input dimension. Any data need to be flattened before going in MLP. Note that it is not required in convolutional networks. 170 - **hidden_dims** (`list[int]`): list of hidden layer dimensions. It can be empty list which means single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension which we take as output dimension. 171 - **output_dim** (`int`): the output dimension which connects to CL output heads. 172 - **activation_layer** (`nn.Module` | `None`): activation function of each layer (if not `None`), if `None` this layer won't be used. Default `nn.ReLU`. 173 - **batch_normalization** (`bool`): whether to use batch normalization after the fully-connected layers. Default `False`. 174 - **bias** (`bool`): whether to use bias in the linear layer. Default `True`. 175 - **dropout** (`float` | `None`): the probability for the dropout layer, if `None` this layer won't be used. Default `None`. - **kwargs**: Reserved for multiple inheritance. 176 """ 177 super().__init__( 178 input_dim=input_dim, 179 hidden_dims=hidden_dims, 180 output_dim=output_dim, 181 activation_layer=activation_layer, 182 batch_normalization=batch_normalization, 183 bias=bias, 184 dropout=dropout, 185 **kwargs, 186 ) 187 188 def forward( 189 self, input: Tensor, stage: str = None, task_id: int | None = None 190 ) -> tuple[Tensor, dict[str, Tensor]]: 191 r"""The forward pass for data. It is the same for all tasks. 192 193 **Args:** 194 - **input** (`Tensor`): The input tensor from data. 195 196 **Returns:** 197 - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation. 198 - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name; value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes. 
199 """ 200 return MLP.forward(self, input, stage) # call the MLP forward method 201 202 203class HATMaskMLP(HATMaskBackbone, MLP): 204 r"""HAT-masked multi-layer perceptron (MLP). 205 206 [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters. 207 208 MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads. 209 210 The mask is applied to units (neurons) in each fully connected layer. The mask is generated from the neuron-wise task embedding and the gate function. 211 """ 212 213 def __init__( 214 self, 215 input_dim: int, 216 hidden_dims: list[int], 217 output_dim: int, 218 gate: str, 219 activation_layer: nn.Module | None = nn.ReLU, 220 batch_normalization: str | None = None, 221 bias: bool = True, 222 dropout: float | None = None, 223 ) -> None: 224 r"""Construct and initialize the HAT-masked MLP backbone network with task embeddings. Note that batch normalization is incompatible with the HAT mechanism. 225 226 **Args:** 227 - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP. 228 - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension. 229 - **output_dim** (`int`): The output dimension that connects to CL output heads. 230 - **gate** (`str`): The type of gate function turning real-valued task embeddings into attention masks; one of: 231 - `sigmoid`: the sigmoid function. 232 - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`. 233 - **batch_normalization** (`str` | `None`): How to use batch normalization after the fully connected layers; one of: 234 - `None`: no batch normalization layers. 235 - `shared`: use a single batch normalization layer for all tasks. Note that this can cause catastrophic forgetting. 236 - `independent`: use independent batch normalization layers for each task. 237 - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`. 238 - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`. - **kwargs**: Reserved for multiple inheritance. 239 """ 240 super().__init__( 241 output_dim=output_dim, 242 gate=gate, 243 input_dim=input_dim, 244 hidden_dims=hidden_dims, 245 activation_layer=activation_layer, 246 batch_normalization=( 247 True if batch_normalization == "shared" or "independent" else False 248 ), 249 bias=bias, 250 dropout=dropout, 251 ) 252 253 # construct the task embedding for each weighted layer 254 for layer_idx in range(self.num_fc_layers): 255 full_layer_name = f"fc/{layer_idx}" 256 layer_output_dim = ( 257 hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim 258 ) 259 self.task_embedding_t[full_layer_name] = nn.Embedding( 260 num_embeddings=1, embedding_dim=layer_output_dim 261 ) 262 263 self.batch_normalization: str | None = batch_normalization 264 r"""The way to use batch normalization after the fully-connected layers. This overrides the `batch_normalization` argument in `MLP` class. 
""" 265 266 # construct the batch normalization layers if needed 267 if self.batch_normalization == "independent": 268 self.fc_bns: nn.ModuleDict = nn.ModuleDict() # initially empty 269 r"""Independent batch normalization layers are stored in a `ModuleDict`. Keys are task IDs and values are the corresponding batch normalization layers for the `nn.Linear`. We use `ModuleDict` rather than `dict` to ensure `LightningModule` can track these model parameters for purposes such as automatic device transfer and model summaries. 270 271 Note that the task IDs must be string type in order to let `LightningModule` identify this part of the model.""" 272 self.original_fc_bn_state_dict: dict = deepcopy(self.fc_bn.state_dict()) 273 r"""The original batch normalization state dict as the source for creating new independent batch normalization layers. """ 274 275 def setup_task_id(self, task_id: int) -> None: 276 r"""Set up task `task_id`. This must be done before the `forward()` method is called. 277 278 **Args:** 279 - **task_id** (`int`): The target task ID. 280 """ 281 HATMaskBackbone.setup_task_id(self, task_id=task_id) 282 283 if self.batch_normalization == "independent": 284 if self.task_id not in self.fc_bns.keys(): 285 self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn) 286 287 def get_bn(self, stage: str, test_task_id: int | None) -> nn.Module | None: 288 r"""Get the batch normalization layer used in the `forward()` method for different stages. 289 290 **Args:** 291 - **stage** (`str`): The stage of the forward pass; one of: 292 1. 'train': training stage. 293 2. 'validation': validation stage. 294 3. 'test': testing stage. 295 - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`. 296 297 **Returns:** 298 - **fc_bn** (`nn.Module` | `None`): The batch normalization module. 299 """ 300 if self.batch_normalization == "independent" and stage == "test": 301 return self.fc_bns[f"{test_task_id}"] 302 else: 303 return self.fc_bn 304 305 def initialize_independent_bn(self) -> None: 306 r"""Initialize the independent batch normalization layer for the current task. This is called when a new task is created. Applies only when `batch_normalization` is 'independent'.""" 307 308 if self.batch_normalization == "independent": 309 self.fc_bn.load_state_dict(self.original_fc_bn_state_dict) 310 311 def store_bn(self) -> None: 312 r"""Store the batch normalization layer for the current task `self.task_id`. Applies only when `batch_normalization` is 'independent'.""" 313 314 if self.batch_normalization == "independent": 315 self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn) 316 317 def forward( 318 self, 319 input: Tensor, 320 stage: str, 321 s_max: float | None = None, 322 batch_idx: int | None = None, 323 num_batches: int | None = None, 324 test_task_id: int | None = None, 325 ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]: 326 r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer. 327 328 **Args:** 329 - **input** (`Tensor`): The input tensor from data. 330 - **stage** (`str`): The stage of the forward pass; one of: 331 1. 'train': training stage. 332 2. 'validation': validation stage. 333 3. 'test': testing stage. 334 - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 
        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
        """
        batch_size = input.size(0)
        activations = {}

        mask = self.get_mask(
            stage=stage,
            s_max=s_max,
            batch_idx=batch_idx,
            num_batches=num_batches,
            test_task_id=test_task_id,
        )
        if self.batch_normalization:
            fc_bn = self.get_bn(stage=stage, test_task_id=test_task_id)
        x = input.view(batch_size, -1)  # flatten before going through MLP

        for layer_idx, layer_name in enumerate(self.weighted_layer_names):

            x = self.fc[layer_idx](x)  # fully-connected layer first
            if self.batch_normalization:
                x = fc_bn[layer_idx](x)  # batch normalization second
            x = x * mask[f"fc/{layer_idx}"]  # apply the task mask to the units

            if self.activation:
                x = self.fc_activation[layer_idx](x)  # activation function third
            activations[layer_name] = x  # store the hidden feature
            if self.dropout:
                x = self.fc_dropout[layer_idx](x)  # dropout last

        output_feature = x

        return output_feature, mask, activations


class WSNMaskMLP(MLP, WSNMaskBackbone):
    r"""[WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) masked multi-layer perceptron (MLP).

    [WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) is an architecture-based continual learning algorithm. It trains learnable parameter-wise importance and selects the most important $c\%$ of the network parameters to be used for each task.

    MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.

    The mask is applied to the weights and biases in each fully connected layer. The mask is generated from the parameter-wise score and the gate function.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        output_dim: int,
        activation_layer: nn.Module | None = nn.ReLU,
        bias: bool = True,
        dropout: float | None = None,
    ) -> None:
        r"""Construct and initialize the WSN-masked MLP backbone network with parameter-wise scores.

        **Args:**
        - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP.
        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
        - **output_dim** (`int`): The output dimension that connects to CL output heads.
        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
        """
        # init from both inherited classes
        super().__init__(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            activation_layer=activation_layer,
            batch_normalization=False,
            bias=bias,
            dropout=dropout,
        )

        # construct the parameter score for each weighted layer
        for layer_idx in range(self.num_fc_layers):
            full_layer_name = f"fc/{layer_idx}"
            layer_input_dim = (
                input_dim if layer_idx == 0 else hidden_dims[layer_idx - 1]
            )
            layer_output_dim = (
                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
            )
            self.weight_score_t[full_layer_name] = nn.Embedding(
                num_embeddings=layer_output_dim,
                embedding_dim=layer_input_dim,
            )
            self.bias_score_t[full_layer_name] = nn.Embedding(
                num_embeddings=1,
                embedding_dim=layer_output_dim,
            )

    def forward(
        self,
        input: Tensor,
        stage: str,
        mask_percentage: float,
        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the weights and biases in each fully connected layer.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): The stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.

        **Returns:**
        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features for various purposes.
457 """ 458 batch_size = input.size(0) 459 activations = {} 460 461 weight_mask, bias_mask = self.get_mask( 462 stage, 463 mask_percentage=mask_percentage, 464 test_mask=test_mask, 465 ) 466 467 x = input.view(batch_size, -1) # flatten before going through MLP 468 469 for layer_idx, layer_name in enumerate(self.weighted_layer_names): 470 weighted_layer = self.fc[layer_idx] 471 weight = weighted_layer.weight 472 bias = weighted_layer.bias 473 474 # mask the weight and bias 475 masked_weight = weight * weight_mask[f"fc/{layer_idx}"] 476 if bias is not None and bias_mask[f"fc/{layer_idx}"] is not None: 477 masked_bias = bias * bias_mask[f"fc/{layer_idx}"] 478 else: 479 masked_bias = None 480 481 # do the forward pass using the masked weight and bias. Do not modify the weight and bias data in the original layer object or it will lose the computation graph. 482 x = nn.functional.linear(x, masked_weight, masked_bias) 483 484 if self.activation: 485 x = self.fc_activation[layer_idx](x) # activation function third 486 activations[layer_name] = x # store the hidden feature 487 if self.dropout: 488 x = self.fc_dropout[layer_idx](x) # dropout last 489 490 output_feature = x 491 492 return output_feature, weight_mask, bias_mask, activations 493 494 495# r""" 496# The submodule in `backbones` for [NISPA (Neuro-Inspired Stability-Plasticity Adaptation)](https://proceedings.mlr.press/v162/gurbuz22a/gurbuz22a.pdf) masked MLP backbone network. 497# """ 498 499# __all__ = ["NISPAMaskMLP"] 500 501# from torch import Tensor, nn 502 503# from clarena.backbones import MLP, NISPAMaskBackbone 504 505 506# class NISPAMaskMLP(MLP, NISPAMaskBackbone): 507# r"""[NISPA (Neuro-Inspired Stability-Plasticity Adaptation)](https://proceedings.mlr.press/v162/gurbuz22a/gurbuz22a.pdf) masked multi-Layer perceptron (MLP). 508 509# [NISPA (Neuro-Inspired Stability-Plasticity Adaptation)](https://proceedings.mlr.press/v162/gurbuz22a/gurbuz22a.pdf) is an architecture-based continual learning algorithm. It 510 511# MLP is a dense network architecture, which has several fully-connected layers, each followed by an activation function. The last layer connects to the CL output heads. 512 513# Mask is applied to the weights and biases in each fully-connected layer. The mask is generated from the parameter-wise score and gate function. 514# """ 515 516# def __init__( 517# self, 518# input_dim: int, 519# hidden_dims: list[int], 520# output_dim: int, 521# activation_layer: nn.Module | None = nn.ReLU, 522# bias: bool = True, 523# dropout: float | None = None, 524# ) -> None: 525# r"""Construct and initialize the WSN masked MLP backbone network with task embedding. Note that batch normalization is incompatible with WSN mechanism. 526 527# **Args:** 528# - **input_dim** (`int`): the input dimension. Any data need to be flattened before going in MLP. 529# - **hidden_dims** (`list[int]`): list of hidden layer dimensions. It can be empty list which means single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension which we take as output dimension. 530# - **output_dim** (`int`): the output dimension which connects to CL output heads. 531# - **activation_layer** (`nn.Module` | `None`): activation function of each layer (if not `None`), if `None` this layer won't be used. Default `nn.ReLU`. 532# - **bias** (`bool`): whether to use bias in the linear layer. Default `True`. 533# - **dropout** (`float` | `None`): the probability for the dropout layer, if `None` this layer won't be used. Default `None`. 
534# """ 535# # init from both inherited classes 536# super().__init__( 537# input_dim=input_dim, 538# hidden_dims=hidden_dims, 539# output_dim=output_dim, 540# activation_layer=activation_layer, 541# batch_normalization=False, # batch normalization is incompatible with HAT mechanism 542# bias=bias, 543# dropout=dropout, 544# ) 545# self.register_wsn_mask_module_explicitly() # register all `nn.Module`s for WSNMaskBackbone explicitly because the second `__init__()` wipes out them inited by the first `__init__()` 546 547# # construct the parameter score for each weighted layer 548# for layer_idx in range(self.num_fc_layers): 549# full_layer_name = f"fc/{layer_idx}" 550# layer_input_dim = ( 551# input_dim if layer_idx == 0 else hidden_dims[layer_idx - 1] 552# ) 553# layer_output_dim = ( 554# hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim 555# ) 556# self.weight_score_t[full_layer_name] = nn.Embedding( 557# num_embeddings=layer_output_dim, 558# embedding_dim=layer_input_dim, 559# ) 560# self.bias_score_t[full_layer_name] = nn.Embedding( 561# num_embeddings=1, 562# embedding_dim=layer_output_dim, 563# ) 564 565# def forward( 566# self, 567# input: Tensor, 568# stage: str, 569# mask_percentage: float, 570# test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None, 571# ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]: 572# r"""The forward pass for data from task `self.task_id`. Task-specific mask for `self.task_id` are applied to the units which are neurons in each fully-connected layer. 573 574# **Args:** 575# - **input** (`Tensor`): The input tensor from data. 576# - **stage** (`str`): the stage of the forward pass; one of: 577# 1. 'train': training stage. 578# 2. 'validation': validation stage. 579# 3. 'test': testing stage. 580# - **mask_percentage** (`float`): the percentage of parameters to be masked. The value should be between 0 and 1. 581# - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): the binary weight and bias mask used for test. Applies only to testing stage. For other stages, it is default `None`. 582 583# **Returns:** 584# - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation. 585# - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, input features) as weight. 586# - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, ) as bias. If the layer doesn't have bias, it is `None`. 587# - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. 
588# """ 589# batch_size = input.size(0) 590# activations = {} 591 592# weight_mask, bias_mask = self.get_mask( 593# stage, 594# mask_percentage=mask_percentage, 595# test_mask=test_mask, 596# ) 597 598# x = input.view(batch_size, -1) # flatten before going through MLP 599 600# for layer_idx, layer_name in enumerate(self.weighted_layer_names): 601# weighted_layer = self.fc[layer_idx] 602# weight = weighted_layer.weight 603# bias = weighted_layer.bias 604 605# # mask the weight and bias 606# masked_weight = weight * weight_mask[f"fc/{layer_idx}"] 607# if bias is not None and bias_mask[f"fc/{layer_idx}"] is not None: 608# masked_bias = bias * bias_mask[f"fc/{layer_idx}"] 609# else: 610# masked_bias = None 611 612# # do the forward pass using the masked weight and bias. Do not modify the weight and bias data in the original layer object or it will lose the computation graph. 613# x = nn.functional.linear(x, masked_weight, masked_bias) 614 615# if self.activation: 616# x = self.fc_activation[layer_idx](x) # activation function third 617# activations[layer_name] = x # store the hidden feature 618# if self.dropout: 619# x = self.fc_dropout[layer_idx](x) # dropout last 620 621# output_feature = x 622 623# return output_feature, weight_mask, bias_mask, activations 624# return output_feature, weight_mask, bias_mask, activations
19class MLP(Backbone): 20 """Multi-layer perceptron (MLP), a.k.a. fully connected network. 21 22 MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the output heads. 23 """ 24 25 def __init__( 26 self, 27 input_dim: int, 28 hidden_dims: list[int], 29 output_dim: int, 30 activation_layer: nn.Module | None = nn.ReLU, 31 batch_normalization: bool = False, 32 bias: bool = True, 33 dropout: float | None = None, 34 **kwargs, 35 ) -> None: 36 r"""Construct and initialize the MLP backbone network. 37 38 **Args:** 39 - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP. Note that it is not required in convolutional networks. 40 - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension. 41 - **output_dim** (`int`): The output dimension that connects to output heads. 42 - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`. 43 - **batch_normalization** (`bool`): Whether to use batch normalization after the fully connected layers. Default `False`. 44 - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`. 45 - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`. 46 - **kwargs**: Reserved for multiple inheritance. 47 """ 48 super().__init__(output_dim=output_dim, **kwargs) 49 50 self.input_dim: int = input_dim 51 r"""The input dimension of the MLP backbone network.""" 52 self.hidden_dims: list[int] = hidden_dims 53 r"""The hidden dimensions of the MLP backbone network.""" 54 self.output_dim: int = output_dim 55 r"""The output dimension of the MLP backbone network.""" 56 57 self.num_fc_layers: int = len(hidden_dims) + 1 58 r"""The number of fully-connected layers in the MLP backbone network, which helps form the loops in constructing layers and forward pass.""" 59 self.batch_normalization: bool = batch_normalization 60 r"""Whether to use batch normalization after the fully-connected layers.""" 61 self.activation: bool = activation_layer is not None 62 r"""Whether to use activation function after the fully-connected layers.""" 63 self.dropout: bool = dropout is not None 64 r"""Whether to use dropout after the fully-connected layers.""" 65 66 self.fc: nn.ModuleList = nn.ModuleList() 67 r"""The list of fully connected (`nn.Linear`) layers.""" 68 if self.batch_normalization: 69 self.fc_bn: nn.ModuleList = nn.ModuleList() 70 r"""The list of batch normalization (`nn.BatchNorm1d`) layers after the fully connected layers.""" 71 if self.activation: 72 self.fc_activation: nn.ModuleList = nn.ModuleList() 73 r"""The list of activation layers after the fully connected layers.""" 74 if self.dropout: 75 self.fc_dropout: nn.ModuleList = nn.ModuleList() 76 r"""The list of dropout layers after the fully connected layers.""" 77 78 # construct the weighted fully connected layers and attached layers (batch norm, activation, dropout, etc.) 
in a loop 79 for layer_idx in range(self.num_fc_layers): 80 81 # the input and output dim of the current weighted layer 82 layer_input_dim = ( 83 self.input_dim if layer_idx == 0 else self.hidden_dims[layer_idx - 1] 84 ) 85 layer_output_dim = ( 86 self.hidden_dims[layer_idx] 87 if layer_idx != len(self.hidden_dims) 88 else self.output_dim 89 ) 90 91 # construct the fully connected layer 92 self.fc.append( 93 nn.Linear( 94 in_features=layer_input_dim, 95 out_features=layer_output_dim, 96 bias=bias, 97 ) 98 ) 99 100 # update the weighted layer names 101 full_layer_name = f"fc/{layer_idx}" 102 self.weighted_layer_names.append(full_layer_name) 103 104 # construct the batch normalization layer 105 if self.batch_normalization: 106 self.fc_bn.append(nn.BatchNorm1d(num_features=(layer_output_dim))) 107 108 # construct the activation layer 109 if self.activation: 110 self.fc_activation.append(activation_layer()) 111 112 # construct the dropout layer 113 if self.dropout: 114 self.fc_dropout.append(nn.Dropout(dropout)) 115 116 def forward( 117 self, input: Tensor, stage: str = None 118 ) -> tuple[Tensor, dict[str, Tensor]]: 119 r"""The forward pass for data. It is the same for all tasks. 120 121 **Args:** 122 - **input** (`Tensor`): The input tensor from data. 123 124 **Returns:** 125 - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation. 126 - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for certain algorithms that need to use hidden features for various purposes. 127 """ 128 batch_size = input.size(0) 129 activations = {} 130 131 x = input.view(batch_size, -1) # flatten before going through MLP 132 133 for layer_idx, layer_name in enumerate(self.weighted_layer_names): 134 x = self.fc[layer_idx](x) # fully-connected layer first 135 if self.batch_normalization: 136 x = self.fc_bn[layer_idx]( 137 x 138 ) # batch normalization can be before or after activation. We put it before activation here 139 if self.activation: 140 x = self.fc_activation[layer_idx](x) # activation function third 141 activations[layer_name] = x # store the hidden feature 142 if self.dropout: 143 x = self.fc_dropout[layer_idx](x) # dropout last 144 145 output_feature = x 146 147 return output_feature, activations
Multi-layer perceptron (MLP), a.k.a. fully connected network.
MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the output heads.
25 def __init__( 26 self, 27 input_dim: int, 28 hidden_dims: list[int], 29 output_dim: int, 30 activation_layer: nn.Module | None = nn.ReLU, 31 batch_normalization: bool = False, 32 bias: bool = True, 33 dropout: float | None = None, 34 **kwargs, 35 ) -> None: 36 r"""Construct and initialize the MLP backbone network. 37 38 **Args:** 39 - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP. Note that it is not required in convolutional networks. 40 - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension. 41 - **output_dim** (`int`): The output dimension that connects to output heads. 42 - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`. 43 - **batch_normalization** (`bool`): Whether to use batch normalization after the fully connected layers. Default `False`. 44 - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`. 45 - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`. 46 - **kwargs**: Reserved for multiple inheritance. 47 """ 48 super().__init__(output_dim=output_dim, **kwargs) 49 50 self.input_dim: int = input_dim 51 r"""The input dimension of the MLP backbone network.""" 52 self.hidden_dims: list[int] = hidden_dims 53 r"""The hidden dimensions of the MLP backbone network.""" 54 self.output_dim: int = output_dim 55 r"""The output dimension of the MLP backbone network.""" 56 57 self.num_fc_layers: int = len(hidden_dims) + 1 58 r"""The number of fully-connected layers in the MLP backbone network, which helps form the loops in constructing layers and forward pass.""" 59 self.batch_normalization: bool = batch_normalization 60 r"""Whether to use batch normalization after the fully-connected layers.""" 61 self.activation: bool = activation_layer is not None 62 r"""Whether to use activation function after the fully-connected layers.""" 63 self.dropout: bool = dropout is not None 64 r"""Whether to use dropout after the fully-connected layers.""" 65 66 self.fc: nn.ModuleList = nn.ModuleList() 67 r"""The list of fully connected (`nn.Linear`) layers.""" 68 if self.batch_normalization: 69 self.fc_bn: nn.ModuleList = nn.ModuleList() 70 r"""The list of batch normalization (`nn.BatchNorm1d`) layers after the fully connected layers.""" 71 if self.activation: 72 self.fc_activation: nn.ModuleList = nn.ModuleList() 73 r"""The list of activation layers after the fully connected layers.""" 74 if self.dropout: 75 self.fc_dropout: nn.ModuleList = nn.ModuleList() 76 r"""The list of dropout layers after the fully connected layers.""" 77 78 # construct the weighted fully connected layers and attached layers (batch norm, activation, dropout, etc.) 
in a loop 79 for layer_idx in range(self.num_fc_layers): 80 81 # the input and output dim of the current weighted layer 82 layer_input_dim = ( 83 self.input_dim if layer_idx == 0 else self.hidden_dims[layer_idx - 1] 84 ) 85 layer_output_dim = ( 86 self.hidden_dims[layer_idx] 87 if layer_idx != len(self.hidden_dims) 88 else self.output_dim 89 ) 90 91 # construct the fully connected layer 92 self.fc.append( 93 nn.Linear( 94 in_features=layer_input_dim, 95 out_features=layer_output_dim, 96 bias=bias, 97 ) 98 ) 99 100 # update the weighted layer names 101 full_layer_name = f"fc/{layer_idx}" 102 self.weighted_layer_names.append(full_layer_name) 103 104 # construct the batch normalization layer 105 if self.batch_normalization: 106 self.fc_bn.append(nn.BatchNorm1d(num_features=(layer_output_dim))) 107 108 # construct the activation layer 109 if self.activation: 110 self.fc_activation.append(activation_layer()) 111 112 # construct the dropout layer 113 if self.dropout: 114 self.fc_dropout.append(nn.Dropout(dropout))
Construct and initialize the MLP backbone network.
Args:
- input_dim (
int): The input dimension. Any data need to be flattened before entering the MLP. Note that it is not required in convolutional networks. - hidden_dims (
list[int]): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension. - output_dim (
int): The output dimension that connects to output heads. - activation_layer (
nn.Module|None): Activation function of each layer (if notNone). IfNone, this layer won't be used. Defaultnn.ReLU. - batch_normalization (
bool): Whether to use batch normalization after the fully connected layers. DefaultFalse. - bias (
bool): Whether to use bias in the linear layer. DefaultTrue. - dropout (
float|None): The probability for the dropout layer. IfNone, this layer won't be used. DefaultNone. - kwargs: Reserved for multiple inheritance.
The number of fully-connected layers in the MLP backbone network, which helps form the loops in constructing layers and forward pass.
116 def forward( 117 self, input: Tensor, stage: str = None 118 ) -> tuple[Tensor, dict[str, Tensor]]: 119 r"""The forward pass for data. It is the same for all tasks. 120 121 **Args:** 122 - **input** (`Tensor`): The input tensor from data. 123 124 **Returns:** 125 - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation. 126 - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for certain algorithms that need to use hidden features for various purposes. 127 """ 128 batch_size = input.size(0) 129 activations = {} 130 131 x = input.view(batch_size, -1) # flatten before going through MLP 132 133 for layer_idx, layer_name in enumerate(self.weighted_layer_names): 134 x = self.fc[layer_idx](x) # fully-connected layer first 135 if self.batch_normalization: 136 x = self.fc_bn[layer_idx]( 137 x 138 ) # batch normalization can be before or after activation. We put it before activation here 139 if self.activation: 140 x = self.fc_activation[layer_idx](x) # activation function third 141 activations[layer_name] = x # store the hidden feature 142 if self.dropout: 143 x = self.fc_dropout[layer_idx](x) # dropout last 144 145 output_feature = x 146 147 return output_feature, activations
The forward pass for data. It is the same for all tasks.
Args:
- input (
Tensor): The input tensor from data.
Returns:
- output_feature (
Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation. - activations (
dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. This is used for certain algorithms that need to use hidden features for various purposes.
150class CLMLP(CLBackbone, MLP): 151 """Multi-layer perceptron (MLP), a.k.a. fully connected network. Used as a continual learning backbone. 152 153 MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads. 154 """ 155 156 def __init__( 157 self, 158 input_dim: int, 159 hidden_dims: list[int], 160 output_dim: int, 161 activation_layer: nn.Module | None = nn.ReLU, 162 batch_normalization: bool = False, 163 bias: bool = True, 164 dropout: float | None = None, 165 **kwargs, 166 ) -> None: 167 r"""Construct and initialize the CLMLP backbone network. 168 169 **Args:** 170 - **input_dim** (`int`): the input dimension. Any data need to be flattened before going in MLP. Note that it is not required in convolutional networks. 171 - **hidden_dims** (`list[int]`): list of hidden layer dimensions. It can be empty list which means single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension which we take as output dimension. 172 - **output_dim** (`int`): the output dimension which connects to CL output heads. 173 - **activation_layer** (`nn.Module` | `None`): activation function of each layer (if not `None`), if `None` this layer won't be used. Default `nn.ReLU`. 174 - **batch_normalization** (`bool`): whether to use batch normalization after the fully-connected layers. Default `False`. 175 - **bias** (`bool`): whether to use bias in the linear layer. Default `True`. 176 - **dropout** (`float` | `None`): the probability for the dropout layer, if `None` this layer won't be used. Default `None`. - **kwargs**: Reserved for multiple inheritance. 177 """ 178 super().__init__( 179 input_dim=input_dim, 180 hidden_dims=hidden_dims, 181 output_dim=output_dim, 182 activation_layer=activation_layer, 183 batch_normalization=batch_normalization, 184 bias=bias, 185 dropout=dropout, 186 **kwargs, 187 ) 188 189 def forward( 190 self, input: Tensor, stage: str = None, task_id: int | None = None 191 ) -> tuple[Tensor, dict[str, Tensor]]: 192 r"""The forward pass for data. It is the same for all tasks. 193 194 **Args:** 195 - **input** (`Tensor`): The input tensor from data. 196 197 **Returns:** 198 - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation. 199 - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name; value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes. 200 """ 201 return MLP.forward(self, input, stage) # call the MLP forward method
Multi-layer perceptron (MLP), a.k.a. fully connected network. Used as a continual learning backbone.
MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.
156 def __init__( 157 self, 158 input_dim: int, 159 hidden_dims: list[int], 160 output_dim: int, 161 activation_layer: nn.Module | None = nn.ReLU, 162 batch_normalization: bool = False, 163 bias: bool = True, 164 dropout: float | None = None, 165 **kwargs, 166 ) -> None: 167 r"""Construct and initialize the CLMLP backbone network. 168 169 **Args:** 170 - **input_dim** (`int`): the input dimension. Any data need to be flattened before going in MLP. Note that it is not required in convolutional networks. 171 - **hidden_dims** (`list[int]`): list of hidden layer dimensions. It can be empty list which means single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension which we take as output dimension. 172 - **output_dim** (`int`): the output dimension which connects to CL output heads. 173 - **activation_layer** (`nn.Module` | `None`): activation function of each layer (if not `None`), if `None` this layer won't be used. Default `nn.ReLU`. 174 - **batch_normalization** (`bool`): whether to use batch normalization after the fully-connected layers. Default `False`. 175 - **bias** (`bool`): whether to use bias in the linear layer. Default `True`. 176 - **dropout** (`float` | `None`): the probability for the dropout layer, if `None` this layer won't be used. Default `None`. - **kwargs**: Reserved for multiple inheritance. 177 """ 178 super().__init__( 179 input_dim=input_dim, 180 hidden_dims=hidden_dims, 181 output_dim=output_dim, 182 activation_layer=activation_layer, 183 batch_normalization=batch_normalization, 184 bias=bias, 185 dropout=dropout, 186 **kwargs, 187 )
Construct and initialize the CLMLP backbone network.
Args:
- input_dim (
int): the input dimension. Any data need to be flattened before going in MLP. Note that it is not required in convolutional networks. - hidden_dims (
list[int]): list of hidden layer dimensions. It can be empty list which means single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension which we take as output dimension. - output_dim (
int): the output dimension which connects to CL output heads. - activation_layer (
nn.Module|None): activation function of each layer (if notNone), ifNonethis layer won't be used. Defaultnn.ReLU. - batch_normalization (
bool): whether to use batch normalization after the fully-connected layers. DefaultFalse. - bias (
bool): whether to use bias in the linear layer. DefaultTrue. - dropout (
float|None): the probability for the dropout layer, ifNonethis layer won't be used. DefaultNone. - kwargs: Reserved for multiple inheritance.
189 def forward( 190 self, input: Tensor, stage: str = None, task_id: int | None = None 191 ) -> tuple[Tensor, dict[str, Tensor]]: 192 r"""The forward pass for data. It is the same for all tasks. 193 194 **Args:** 195 - **input** (`Tensor`): The input tensor from data. 196 197 **Returns:** 198 - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation. 199 - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name; value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes. 200 """ 201 return MLP.forward(self, input, stage) # call the MLP forward method
The forward pass for data. It is the same for all tasks.
Args:
- input (
Tensor): The input tensor from data.
Returns:
- output_feature (
Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation. - activations (
dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name; value (Tensor) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
Inherited Members
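When batch_normalization='independent' is chosen for HATMaskMLP, the per-task batch normalization layers are managed through setup_task_id(), initialize_independent_bn(), store_bn() and get_bn(). A hypothetical outline of how those calls might be sequenced across tasks, based only on the method docstrings above (the exact call order in clarena's training loop is not shown in this file):

backbone = HATMaskMLP(
    input_dim=784, hidden_dims=[256, 256], output_dim=64,
    gate="sigmoid", batch_normalization="independent",
)

for task_id in (1, 2, 3):
    backbone.setup_task_id(task_id)       # registers fc_bns[str(task_id)] if it doesn't exist yet
    backbone.initialize_independent_bn()  # reset fc_bn from the stored original state dict
    # ... train on task `task_id`, calling backbone(input, stage="train", s_max=..., ...)
    backbone.store_bn()                   # snapshot the trained fc_bn into fc_bns[str(task_id)]

# at test time, forward(..., stage="test", test_task_id=k) uses get_bn() to pick fc_bns[str(k)]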
class HATMaskMLP(HATMaskBackbone, MLP):
HAT-masked multi-layer perceptron (MLP).
HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.
The mask is applied to units (neurons) in each fully connected layer. The mask is generated from the neuron-wise task embedding and the gate function.
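To make the gate-and-embedding mechanism concrete, here is a minimal conceptual sketch. It is not the module's own `get_mask()` implementation; it only assumes a sigmoid gate and the annealing schedule described in Sec. 2.4 of the HAT paper, where the scaling factor `s` is ramped from roughly `1/s_max` to `s_max` over the batches of an epoch.

```python
import torch


def hat_unit_mask(task_embedding: torch.Tensor, s: float) -> torch.Tensor:
    """Turn a real-valued per-neuron task embedding into a unit-wise attention mask.

    Large `s` pushes the sigmoid towards a hard 0/1 mask; small `s` keeps it soft
    so that gradients can still flow into the embedding.
    """
    return torch.sigmoid(s * task_embedding)


def anneal_s(batch_idx: int, num_batches: int, s_max: float) -> float:
    """Sketch of the HAT annealing schedule within one epoch."""
    return 1 / s_max + (s_max - 1 / s_max) * batch_idx / max(num_batches - 1, 1)


e = torch.randn(256)  # hypothetical embedding for a 256-unit layer
soft_mask = hat_unit_mask(e, anneal_s(batch_idx=10, num_batches=100, s_max=400.0))
hard_mask = hat_unit_mask(e, s=400.0)  # near-binary mask, as used at test time
```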
def __init__(
    self,
    input_dim: int,
    hidden_dims: list[int],
    output_dim: int,
    gate: str,
    activation_layer: nn.Module | None = nn.ReLU,
    batch_normalization: str | None = None,
    bias: bool = True,
    dropout: float | None = None,
) -> None:
    r"""Construct and initialize the HAT-masked MLP backbone network with task embeddings. Note that batch normalization is incompatible with the HAT mechanism.

    **Args:**
    - **input_dim** (`int`): The input dimension. Any data needs to be flattened before entering the MLP.
    - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can have as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
    - **output_dim** (`int`): The output dimension that connects to CL output heads.
    - **gate** (`str`): The type of gate function turning real-valued task embeddings into attention masks; one of:
        - `sigmoid`: the sigmoid function.
    - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
    - **batch_normalization** (`str` | `None`): How to use batch normalization after the fully connected layers; one of:
        - `None`: no batch normalization layers.
        - `shared`: use a single batch normalization layer for all tasks. Note that this can cause catastrophic forgetting.
        - `independent`: use independent batch normalization layers for each task.
    - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
    - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
    """
    super().__init__(
        output_dim=output_dim,
        gate=gate,
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        activation_layer=activation_layer,
        # enable batch normalization in the parent `MLP` for both the "shared" and "independent" modes
        batch_normalization=(batch_normalization in ("shared", "independent")),
        bias=bias,
        dropout=dropout,
    )

    # construct the task embedding for each weighted layer
    for layer_idx in range(self.num_fc_layers):
        full_layer_name = f"fc/{layer_idx}"
        layer_output_dim = (
            hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
        )
        self.task_embedding_t[full_layer_name] = nn.Embedding(
            num_embeddings=1, embedding_dim=layer_output_dim
        )

    self.batch_normalization: str | None = batch_normalization
    r"""The way to use batch normalization after the fully-connected layers. This overrides the `batch_normalization` argument in the `MLP` class."""

    # construct the batch normalization layers if needed
    if self.batch_normalization == "independent":
        self.fc_bns: nn.ModuleDict = nn.ModuleDict()  # initially empty
        r"""Independent batch normalization layers are stored in a `ModuleDict`. Keys are task IDs and values are the corresponding batch normalization layers for the `nn.Linear`. We use `ModuleDict` rather than `dict` to ensure `LightningModule` can track these model parameters for purposes such as automatic device transfer and model summaries.

        Note that the task IDs must be string type in order to let `LightningModule` identify this part of the model."""
        self.original_fc_bn_state_dict: dict = deepcopy(self.fc_bn.state_dict())
        r"""The original batch normalization state dict as the source for creating new independent batch normalization layers."""
Construct and initialize the HAT-masked MLP backbone network with task embeddings. Note that batch normalization is incompatible with the HAT mechanism.
**Args:**

- **input_dim** (`int`): The input dimension. Any data needs to be flattened before entering the MLP.
- **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can have as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
- **output_dim** (`int`): The output dimension that connects to CL output heads.
- **gate** (`str`): The type of gate function turning real-valued task embeddings into attention masks; one of:
    - `sigmoid`: the sigmoid function.
- **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
- **batch_normalization** (`str` | `None`): How to use batch normalization after the fully connected layers; one of:
    - `None`: no batch normalization layers.
    - `shared`: use a single batch normalization layer for all tasks. Note that this can cause catastrophic forgetting.
    - `independent`: use independent batch normalization layers for each task.
- **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
- **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
**batch_normalization** (`str` | `None`): The way to use batch normalization after the fully-connected layers. This overrides the `batch_normalization` argument in the `MLP` class.
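For orientation, here is a minimal construction sketch. The argument values (input size, hidden sizes, dropout rate) are illustrative only, not recommended settings.

```python
from clarena.backbones.mlp import HATMaskMLP

# A two-hidden-layer HAT-masked MLP for flattened 32x32x3 inputs,
# using per-task ("independent") batch normalization.
backbone = HATMaskMLP(
    input_dim=3 * 32 * 32,
    hidden_dims=[256, 256],
    output_dim=64,
    gate="sigmoid",
    batch_normalization="independent",
    dropout=0.1,
)
```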
def setup_task_id(self, task_id: int) -> None:
    r"""Set up task `task_id`. This must be done before the `forward()` method is called.

    **Args:**
    - **task_id** (`int`): The target task ID.
    """
    HATMaskBackbone.setup_task_id(self, task_id=task_id)

    if self.batch_normalization == "independent":
        if f"{self.task_id}" not in self.fc_bns:  # keys in `fc_bns` are stored as strings
            self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn)
def get_bn(self, stage: str, test_task_id: int | None) -> nn.Module | None:
    r"""Get the batch normalization layer used in the `forward()` method for different stages.

    **Args:**
    - **stage** (`str`): The stage of the forward pass; one of:
        1. 'train': training stage.
        2. 'validation': validation stage.
        3. 'test': testing stage.
    - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

    **Returns:**
    - **fc_bn** (`nn.Module` | `None`): The batch normalization module.
    """
    if self.batch_normalization == "independent" and stage == "test":
        return self.fc_bns[f"{test_task_id}"]
    else:
        return self.fc_bn
Get the batch normalization layer used in the `forward()` method for different stages.

**Args:**

- **stage** (`str`): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

**Returns:**

- **fc_bn** (`nn.Module` | `None`): The batch normalization module.
def initialize_independent_bn(self) -> None:
    r"""Initialize the independent batch normalization layer for the current task. This is called when a new task is created. Applies only when `batch_normalization` is 'independent'."""

    if self.batch_normalization == "independent":
        self.fc_bn.load_state_dict(self.original_fc_bn_state_dict)
Initialize the independent batch normalization layer for the current task. This is called when a new task is created. Applies only when `batch_normalization` is 'independent'.
def store_bn(self) -> None:
    r"""Store the batch normalization layer for the current task `self.task_id`. Applies only when `batch_normalization` is 'independent'."""

    if self.batch_normalization == "independent":
        self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn)
Store the batch normalization layer for the current task `self.task_id`. Applies only when `batch_normalization` is 'independent'.
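Taken together, `setup_task_id()`, `initialize_independent_bn()`, and `store_bn()` manage the per-task batch normalization state. The sketch below, continuing the construction example above, shows one way a training routine might drive them when `batch_normalization="independent"`; the surrounding loop is hypothetical and only meant to illustrate the call order.

```python
for task_id in [1, 2, 3]:
    backbone.setup_task_id(task_id)       # registers an entry for this task in `fc_bns` if missing
    backbone.initialize_independent_bn()  # reset the working BN layers to their original state

    # ... train on task `task_id`, calling backbone.forward(..., stage="train") ...

    backbone.store_bn()                   # snapshot the trained BN layers for this task

# at test time, get_bn() retrieves the stored BN layers of the task being evaluated
```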
def forward(
    self,
    input: Tensor,
    stage: str,
    s_max: float | None = None,
    batch_idx: int | None = None,
    num_batches: int | None = None,
    test_task_id: int | None = None,
) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
    r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer.

    **Args:**
    - **input** (`Tensor`): The input tensor from data.
    - **stage** (`str`): The stage of the forward pass; one of:
        1. 'train': training stage.
        2. 'validation': validation stage.
        3. 'test': testing stage.
    - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
    - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
    - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
    - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

    **Returns:**
    - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
    - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
    - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
    """
    batch_size = input.size(0)
    activations = {}

    mask = self.get_mask(
        stage=stage,
        s_max=s_max,
        batch_idx=batch_idx,
        num_batches=num_batches,
        test_task_id=test_task_id,
    )
    if self.batch_normalization:
        fc_bn = self.get_bn(stage=stage, test_task_id=test_task_id)
    x = input.view(batch_size, -1)  # flatten before going through MLP

    for layer_idx, layer_name in enumerate(self.weighted_layer_names):

        x = self.fc[layer_idx](x)  # fully-connected layer first
        if self.batch_normalization:
            x = fc_bn[layer_idx](x)  # batch normalization second
        x = x * mask[f"fc/{layer_idx}"]  # then apply the unit mask to the features
        if self.activation:
            x = self.fc_activation[layer_idx](x)  # activation function third
        activations[layer_name] = x  # store the hidden feature
        if self.dropout:
            x = self.fc_dropout[layer_idx](x)  # dropout last

    output_feature = x

    return output_feature, mask, activations
The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer.

**Args:**

- **input** (`Tensor`): The input tensor from data.
- **stage** (`str`): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
- **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
- **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
- **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.

**Returns:**

- **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
- **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
- **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
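The two calling conventions can be contrasted with a short sketch, continuing the construction example above. Tensor shapes and hyperparameter values are illustrative, and it is assumed that `setup_task_id()` has already been called for the task being trained.

```python
import torch

x = torch.randn(32, 3 * 32 * 32)   # a batch of flattened inputs (illustrative shape)
batch_idx, num_batches = 10, 100   # illustrative training-progress values

# training: the mask is annealed through s_max, batch_idx, and num_batches
feature, mask, activations = backbone(
    x, stage="train", s_max=400.0, batch_idx=batch_idx, num_batches=num_batches
)

# testing: the near-binary mask of the evaluated task is used instead
feature, mask, activations = backbone(x, stage="test", test_task_id=1)
```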
class WSNMaskMLP(MLP, WSNMaskBackbone):
WSN (Winning Subnetworks, 2022) masked multi-layer perceptron (MLP).
WSN (Winning Subnetworks, 2022) is an architecture-based continual learning algorithm. It trains learnable parameter-wise importance and selects the most important $c\%$ of the network parameters to be used for each task.
MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.
The mask is applied to the weights and biases in each fully connected layer. The mask is generated from the parameter-wise score and the gate function.
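The parameter-selection idea can be illustrated with a small, self-contained sketch: given a real-valued score per weight, keep the top `c` fraction as a binary mask. This is a paraphrase of the WSN idea, not the module's own `get_mask()` implementation.

```python
import torch


def top_c_percent_mask(score: torch.Tensor, c: float) -> torch.Tensor:
    """Binary mask keeping the fraction `c` of entries with the highest score."""
    k = max(1, int(c * score.numel()))
    # threshold = k-th largest score = (n - k + 1)-th smallest score
    threshold = score.flatten().kthvalue(score.numel() - k + 1).values
    return (score >= threshold).float()


scores = torch.randn(256, 784)                    # hypothetical per-weight scores for a 784 -> 256 layer
weight_mask = top_c_percent_mask(scores, c=0.5)   # keep the top 50% of weights
```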
def __init__(
    self,
    input_dim: int,
    hidden_dims: list[int],
    output_dim: int,
    activation_layer: nn.Module | None = nn.ReLU,
    bias: bool = True,
    dropout: float | None = None,
) -> None:
    r"""Construct and initialize the WSN-masked MLP backbone network with task embeddings.

    **Args:**
    - **input_dim** (`int`): The input dimension. Any data needs to be flattened before entering the MLP.
    - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can have as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
    - **output_dim** (`int`): The output dimension that connects to CL output heads.
    - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
    - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
    - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
    """
    # init from both inherited classes
    super().__init__(
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        output_dim=output_dim,
        activation_layer=activation_layer,
        batch_normalization=False,
        bias=bias,
        dropout=dropout,
    )

    # construct the parameter score for each weighted layer
    for layer_idx in range(self.num_fc_layers):
        full_layer_name = f"fc/{layer_idx}"
        layer_input_dim = (
            input_dim if layer_idx == 0 else hidden_dims[layer_idx - 1]
        )
        layer_output_dim = (
            hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
        )
        self.weight_score_t[full_layer_name] = nn.Embedding(
            num_embeddings=layer_output_dim,
            embedding_dim=layer_input_dim,
        )
        self.bias_score_t[full_layer_name] = nn.Embedding(
            num_embeddings=1,
            embedding_dim=layer_output_dim,
        )
Construct and initialize the WSN-masked MLP backbone network with task embeddings.
**Args:**

- **input_dim** (`int`): The input dimension. Any data needs to be flattened before entering the MLP.
- **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can have as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
- **output_dim** (`int`): The output dimension that connects to CL output heads.
- **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
- **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
- **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
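A minimal construction sketch; the argument values are illustrative only.

```python
from clarena.backbones.mlp import WSNMaskMLP

# A two-hidden-layer WSN-masked MLP for flattened 28x28 inputs.
backbone = WSNMaskMLP(
    input_dim=28 * 28,
    hidden_dims=[256, 256],
    output_dim=64,
    dropout=0.1,
)
```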
def forward(
    self,
    input: Tensor,
    stage: str,
    mask_percentage: float,
    test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
    r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the weights and biases in each fully connected layer.

    **Args:**
    - **input** (`Tensor`): The input tensor from data.
    - **stage** (`str`): The stage of the forward pass; one of:
        1. 'train': training stage.
        2. 'validation': validation stage.
        3. 'test': testing stage.
    - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
    - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.

    **Returns:**
    - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
    - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
    - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
    - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features for various purposes.
    """
    batch_size = input.size(0)
    activations = {}

    weight_mask, bias_mask = self.get_mask(
        stage,
        mask_percentage=mask_percentage,
        test_mask=test_mask,
    )

    x = input.view(batch_size, -1)  # flatten before going through MLP

    for layer_idx, layer_name in enumerate(self.weighted_layer_names):
        weighted_layer = self.fc[layer_idx]
        weight = weighted_layer.weight
        bias = weighted_layer.bias

        # mask the weight and bias
        masked_weight = weight * weight_mask[f"fc/{layer_idx}"]
        if bias is not None and bias_mask[f"fc/{layer_idx}"] is not None:
            masked_bias = bias * bias_mask[f"fc/{layer_idx}"]
        else:
            masked_bias = None

        # do the forward pass using the masked weight and bias. Do not modify the weight and bias data in the original layer object, or it will lose the computation graph.
        x = nn.functional.linear(x, masked_weight, masked_bias)

        if self.activation:
            x = self.fc_activation[layer_idx](x)  # activation function
        activations[layer_name] = x  # store the hidden feature
        if self.dropout:
            x = self.fc_dropout[layer_idx](x)  # dropout last

    output_feature = x

    return output_feature, weight_mask, bias_mask, activations
The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to the weights and biases in each fully connected layer.

**Args:**

- **input** (`Tensor`): The input tensor from data.
- **stage** (`str`): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
- **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.

**Returns:**

- **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
- **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
- **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
- **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features for various purposes.
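A short calling sketch, continuing the construction example above. Shapes and the mask percentage are illustrative, and it is assumed the current task has already been set up on the backbone; at test time the binary masks stored for the evaluated task would normally be passed in, here represented by the masks returned from training.

```python
import torch

x = torch.randn(32, 28 * 28)  # a batch of flattened inputs (illustrative shape)

# training / validation: masks are derived on the fly from the parameter scores
feature, weight_mask, bias_mask, activations = backbone(
    x, stage="train", mask_percentage=0.5
)

# testing: pass in the binary weight and bias masks of the task under evaluation
feature, weight_mask, bias_mask, activations = backbone(
    x, stage="test", mask_percentage=0.5, test_mask=(weight_mask, bias_mask)
)
```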