clarena.backbones.mlp

The submodule in backbones for the MLP backbone network. It provides multiple versions of MLP: the basic MLP, the continual learning MLP, the HAT masked MLP, and the WSN masked MLP.
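
For orientation, here is a minimal usage sketch of the basic `MLP` backbone, based on the constructor and `forward()` signatures shown in the source below; the dimensions, dropout rate, and batch size are illustrative assumptions, not values from the library.

```python
import torch

from clarena.backbones.mlp import MLP

# Illustrative dimensions: a 28x28 image flattened to 784 inputs,
# two hidden layers, and a 64-dimensional feature fed into the output heads.
backbone = MLP(input_dim=784, hidden_dims=[256, 128], output_dim=64, dropout=0.5)

x = torch.randn(32, 1, 28, 28)      # forward() flattens the input itself
feature, activations = backbone(x)  # feature: (32, 64)
print(feature.shape, list(activations.keys()))  # activation keys: 'fc/0', 'fc/1', 'fc/2'
```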

  1r"""
  2The submodule in `backbones` for the MLP backbone network. It provides multiple versions of MLP: the basic MLP, the continual learning MLP, the [HAT](http://proceedings.mlr.press/v80/serra18a) masked MLP, and the [WSN](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) masked MLP.
  3"""
  4
  5__all__ = ["MLP", "CLMLP", "HATMaskMLP", "WSNMaskMLP"]
  6
  7import logging
  8from copy import deepcopy
  9
 10from torch import Tensor, nn
 11
 12from clarena.backbones import Backbone, CLBackbone, HATMaskBackbone, WSNMaskBackbone
 13
 14# always get logger for built-in logging in each module
 15pylogger = logging.getLogger(__name__)
 16
 17
 18class MLP(Backbone):
 19    """Multi-layer perceptron (MLP), a.k.a. fully connected network.
 20
 21    MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the output heads.
 22    """
 23
 24    def __init__(
 25        self,
 26        input_dim: int,
 27        hidden_dims: list[int],
 28        output_dim: int,
 29        activation_layer: nn.Module | None = nn.ReLU,
 30        batch_normalization: bool = False,
 31        bias: bool = True,
 32        dropout: float | None = None,
 33        **kwargs,
 34    ) -> None:
 35        r"""Construct and initialize the MLP backbone network.
 36
 37        **Args:**
 38        - **input_dim** (`int`): The input dimension. Any data needs to be flattened before entering the MLP. Note that it is not required in convolutional networks.
 39        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
 40        - **output_dim** (`int`): The output dimension that connects to output heads.
 41        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
 42        - **batch_normalization** (`bool`): Whether to use batch normalization after the fully connected layers. Default `False`.
 43        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
 44        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
 45        - **kwargs**: Reserved for multiple inheritance.
 46        """
 47        super().__init__(output_dim=output_dim, **kwargs)
 48
 49        self.input_dim: int = input_dim
 50        r"""The input dimension of the MLP backbone network."""
 51        self.hidden_dims: list[int] = hidden_dims
 52        r"""The hidden dimensions of the MLP backbone network."""
 53        self.output_dim: int = output_dim
 54        r"""The output dimension of the MLP backbone network."""
 55
 56        self.num_fc_layers: int = len(hidden_dims) + 1
 57        r"""The number of fully-connected layers in the MLP backbone network, which helps form the loops in constructing layers and forward pass."""
 58        self.batch_normalization: bool = batch_normalization
 59        r"""Whether to use batch normalization after the fully-connected layers."""
 60        self.activation: bool = activation_layer is not None
 61        r"""Whether to use activation function after the fully-connected layers."""
 62        self.dropout: bool = dropout is not None
 63        r"""Whether to use dropout after the fully-connected layers."""
 64
 65        self.fc: nn.ModuleList = nn.ModuleList()
 66        r"""The list of fully connected (`nn.Linear`) layers."""
 67        if self.batch_normalization:
 68            self.fc_bn: nn.ModuleList = nn.ModuleList()
 69            r"""The list of batch normalization (`nn.BatchNorm1d`) layers after the fully connected layers."""
 70        if self.activation:
 71            self.fc_activation: nn.ModuleList = nn.ModuleList()
 72            r"""The list of activation layers after the fully connected layers."""
 73        if self.dropout:
 74            self.fc_dropout: nn.ModuleList = nn.ModuleList()
 75            r"""The list of dropout layers after the fully connected layers."""
 76
 77        # construct the weighted fully connected layers and attached layers (batch norm, activation, dropout, etc.) in a loop
 78        for layer_idx in range(self.num_fc_layers):
 79
 80            # the input and output dim of the current weighted layer
 81            layer_input_dim = (
 82                self.input_dim if layer_idx == 0 else self.hidden_dims[layer_idx - 1]
 83            )
 84            layer_output_dim = (
 85                self.hidden_dims[layer_idx]
 86                if layer_idx != len(self.hidden_dims)
 87                else self.output_dim
 88            )
 89
 90            # construct the fully connected layer
 91            self.fc.append(
 92                nn.Linear(
 93                    in_features=layer_input_dim,
 94                    out_features=layer_output_dim,
 95                    bias=bias,
 96                )
 97            )
 98
 99            # update the weighted layer names
100            full_layer_name = f"fc/{layer_idx}"
101            self.weighted_layer_names.append(full_layer_name)
102
103            # construct the batch normalization layer
104            if self.batch_normalization:
105                self.fc_bn.append(nn.BatchNorm1d(num_features=(layer_output_dim)))
106
107            # construct the activation layer
108            if self.activation:
109                self.fc_activation.append(activation_layer())
110
111            # construct the dropout layer
112            if self.dropout:
113                self.fc_dropout.append(nn.Dropout(dropout))
114
115    def forward(
116        self, input: Tensor, stage: str = None
117    ) -> tuple[Tensor, dict[str, Tensor]]:
118        r"""The forward pass for data. It is the same for all tasks.
119
120        **Args:**
121        - **input** (`Tensor`): The input tensor from data.
122
123        **Returns:**
124        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
125        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for certain algorithms that need to use hidden features for various purposes.
126        """
127        batch_size = input.size(0)
128        activations = {}
129
130        x = input.view(batch_size, -1)  # flatten before going through MLP
131
132        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
133            x = self.fc[layer_idx](x)  # fully-connected layer first
134            if self.batch_normalization:
135                x = self.fc_bn[layer_idx](
136                    x
137                )  # batch normalization can be before or after activation. We put it before activation here
138            if self.activation:
139                x = self.fc_activation[layer_idx](x)  # activation function third
140            activations[layer_name] = x  # store the hidden feature
141            if self.dropout:
142                x = self.fc_dropout[layer_idx](x)  # dropout last
143
144        output_feature = x
145
146        return output_feature, activations
147
148
149class CLMLP(CLBackbone, MLP):
150    """Multi-layer perceptron (MLP), a.k.a. fully connected network. Used as a continual learning backbone.
151
152    MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.
153    """
154
155    def __init__(
156        self,
157        input_dim: int,
158        hidden_dims: list[int],
159        output_dim: int,
160        activation_layer: nn.Module | None = nn.ReLU,
161        batch_normalization: bool = False,
162        bias: bool = True,
163        dropout: float | None = None,
164        **kwargs,
165    ) -> None:
166        r"""Construct and initialize the CLMLP backbone network.
167
168        **Args:**
169        - **input_dim** (`int`): The input dimension. Any data needs to be flattened before entering the MLP. Note that it is not required in convolutional networks.
170        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
171        - **output_dim** (`int`): The output dimension that connects to CL output heads.
172        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
173        - **batch_normalization** (`bool`): Whether to use batch normalization after the fully connected layers. Default `False`.
174        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
175        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
           - **kwargs**: Reserved for multiple inheritance.
176        """
177        super().__init__(
178            input_dim=input_dim,
179            hidden_dims=hidden_dims,
180            output_dim=output_dim,
181            activation_layer=activation_layer,
182            batch_normalization=batch_normalization,
183            bias=bias,
184            dropout=dropout,
185            **kwargs,
186        )
187
188    def forward(
189        self, input: Tensor, stage: str = None, task_id: int | None = None
190    ) -> tuple[Tensor, dict[str, Tensor]]:
191        r"""The forward pass for data. It is the same for all tasks.
192
193        **Args:**
194        - **input** (`Tensor`): The input tensor from data.
195
196        **Returns:**
197        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
198        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name; value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
199        """
200        return MLP.forward(self, input, stage)  # call the MLP forward method
201
202
203class HATMaskMLP(HATMaskBackbone, MLP):
204    r"""HAT-masked multi-layer perceptron (MLP).
205
206    [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
207
208    MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.
209
210    The mask is applied to units (neurons) in each fully connected layer. The mask is generated from the neuron-wise task embedding and the gate function.
211    """
212
213    def __init__(
214        self,
215        input_dim: int,
216        hidden_dims: list[int],
217        output_dim: int,
218        gate: str,
219        activation_layer: nn.Module | None = nn.ReLU,
220        batch_normalization: str | None = None,
221        bias: bool = True,
222        dropout: float | None = None,
223    ) -> None:
224        r"""Construct and initialize the HAT-masked MLP backbone network with task embeddings. Note that batch normalization is incompatible with the HAT mechanism.
225
226        **Args:**
227        - **input_dim** (`int`): The input dimension. Any data needs to be flattened before entering the MLP.
228        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
229        - **output_dim** (`int`): The output dimension that connects to CL output heads.
230        - **gate** (`str`): The type of gate function turning real-valued task embeddings into attention masks; one of:
231            - `sigmoid`: the sigmoid function.
232        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
233        - **batch_normalization** (`str` | `None`): How to use batch normalization after the fully connected layers; one of:
234            - `None`: no batch normalization layers.
235            - `shared`: use a single batch normalization layer for all tasks. Note that this can cause catastrophic forgetting.
236            - `independent`: use independent batch normalization layers for each task.
237        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
238        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
239        """
240        super().__init__(
241            output_dim=output_dim,
242            gate=gate,
243            input_dim=input_dim,
244            hidden_dims=hidden_dims,
245            activation_layer=activation_layer,
246            batch_normalization=(
247                batch_normalization in ("shared", "independent")
248            ),
249            bias=bias,
250            dropout=dropout,
251        )
252
253        # construct the task embedding for each weighted layer
254        for layer_idx in range(self.num_fc_layers):
255            full_layer_name = f"fc/{layer_idx}"
256            layer_output_dim = (
257                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
258            )
259            self.task_embedding_t[full_layer_name] = nn.Embedding(
260                num_embeddings=1, embedding_dim=layer_output_dim
261            )
262
263        self.batch_normalization: str | None = batch_normalization
264        r"""The way to use batch normalization after the fully-connected layers. This overrides the `batch_normalization` argument in `MLP` class. """
265
266        # construct the batch normalization layers if needed
267        if self.batch_normalization == "independent":
268            self.fc_bns: nn.ModuleDict = nn.ModuleDict()  # initially empty
269            r"""Independent batch normalization layers are stored in a `ModuleDict`. Keys are task IDs and values are the corresponding batch normalization layers for the `nn.Linear`. We use `ModuleDict` rather than `dict` to ensure `LightningModule` can track these model parameters for purposes such as automatic device transfer and model summaries.
270            
271            Note that the task IDs must be string type in order to let `LightningModule` identify this part of the model."""
272            self.original_fc_bn_state_dict: dict = deepcopy(self.fc_bn.state_dict())
273            r"""The original batch normalization state dict as the source for creating new independent batch normalization layers. """
274
275    def setup_task_id(self, task_id: int) -> None:
276        r"""Set up task `task_id`. This must be done before the `forward()` method is called.
277
278        **Args:**
279        - **task_id** (`int`): The target task ID.
280        """
281        HATMaskBackbone.setup_task_id(self, task_id=task_id)
282
283        if self.batch_normalization == "independent":
284            if self.task_id not in self.fc_bns.keys():
285                self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn)
286
287    def get_bn(self, stage: str, test_task_id: int | None) -> nn.Module | None:
288        r"""Get the batch normalization layer used in the `forward()` method for different stages.
289
290        **Args:**
291        - **stage** (`str`): The stage of the forward pass; one of:
292            1. 'train': training stage.
293            2. 'validation': validation stage.
294            3. 'test': testing stage.
295        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
296
297        **Returns:**
298        - **fc_bn** (`nn.Module` | `None`): The batch normalization module.
299        """
300        if self.batch_normalization == "independent" and stage == "test":
301            return self.fc_bns[f"{test_task_id}"]
302        else:
303            return self.fc_bn
304
305    def initialize_independent_bn(self) -> None:
306        r"""Initialize the independent batch normalization layer for the current task. This is called when a new task is created. Applies only when `batch_normalization` is 'independent'."""
307
308        if self.batch_normalization == "independent":
309            self.fc_bn.load_state_dict(self.original_fc_bn_state_dict)
310
311    def store_bn(self) -> None:
312        r"""Store the batch normalization layer for the current task `self.task_id`. Applies only when `batch_normalization` is 'independent'."""
313
314        if self.batch_normalization == "independent":
315            self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn)
316
317    def forward(
318        self,
319        input: Tensor,
320        stage: str,
321        s_max: float | None = None,
322        batch_idx: int | None = None,
323        num_batches: int | None = None,
324        test_task_id: int | None = None,
325    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
326        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer.
327
328        **Args:**
329        - **input** (`Tensor`): The input tensor from data.
330        - **stage** (`str`): The stage of the forward pass; one of:
331            1. 'train': training stage.
332            2. 'validation': validation stage.
333            3. 'test': testing stage.
334        - **s_max** (`float` | `None`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
335        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
336        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
337        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
338
339        **Returns:**
340        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
341        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
342        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency for other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
343        """
344        batch_size = input.size(0)
345        activations = {}
346
347        mask = self.get_mask(
348            stage=stage,
349            s_max=s_max,
350            batch_idx=batch_idx,
351            num_batches=num_batches,
352            test_task_id=test_task_id,
353        )
354        if self.batch_normalization:
355            fc_bn = self.get_bn(stage=stage, test_task_id=test_task_id)
356        x = input.view(batch_size, -1)  # flatten before going through MLP
357
358        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
359
360            x = self.fc[layer_idx](x)  # fully-connected layer first
361            if self.batch_normalization:
362                x = fc_bn[layer_idx](x)  # batch normalization second
363            x = x * mask[f"fc/{layer_idx}"]  # apply the task mask to the units (neurons)
364
365            if self.activation:
366                x = self.fc_activation[layer_idx](x)  # activation function third
367            activations[layer_name] = x  # store the hidden feature
368            if self.dropout:
369                x = self.fc_dropout[layer_idx](x)  # dropout last
370
371        output_feature = x
372
373        return output_feature, mask, activations
374
375
376class WSNMaskMLP(MLP, WSNMaskBackbone):
377    r"""[WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) masked multi-layer perceptron (MLP).
378
379    [WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) is an architecture-based continual learning algorithm. It trains learnable parameter-wise importance and selects the most important $c\%$ of the network parameters to be used for each task.
380
381    MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.
382
383    The mask is applied to the weights and biases in each fully connected layer. The mask is generated from the parameter-wise score and the gate function.
384    """
385
386    def __init__(
387        self,
388        input_dim: int,
389        hidden_dims: list[int],
390        output_dim: int,
391        activation_layer: nn.Module | None = nn.ReLU,
392        bias: bool = True,
393        dropout: float | None = None,
394    ) -> None:
395        r"""Construct and initialize the WSN-masked MLP backbone network with task embeddings.
396
397        **Args:**
398        - **input_dim** (`int`): The input dimension. Any data needs to be flattened before entering the MLP.
399        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
400        - **output_dim** (`int`): The output dimension that connects to CL output heads.
401        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
402        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
403        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
404        """
405        # init from both inherited classes
406        super().__init__(
407            input_dim=input_dim,
408            hidden_dims=hidden_dims,
409            output_dim=output_dim,
410            activation_layer=activation_layer,
411            batch_normalization=False,
412            bias=bias,
413            dropout=dropout,
414        )
415
416        # construct the parameter score for each weighted layer
417        for layer_idx in range(self.num_fc_layers):
418            full_layer_name = f"fc/{layer_idx}"
419            layer_input_dim = (
420                input_dim if layer_idx == 0 else hidden_dims[layer_idx - 1]
421            )
422            layer_output_dim = (
423                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
424            )
425            self.weight_score_t[full_layer_name] = nn.Embedding(
426                num_embeddings=layer_output_dim,
427                embedding_dim=layer_input_dim,
428            )
429            self.bias_score_t[full_layer_name] = nn.Embedding(
430                num_embeddings=1,
431                embedding_dim=layer_output_dim,
432            )
433
434    def forward(
435        self,
436        input: Tensor,
437        stage: str,
438        mask_percentage: float,
439        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
440    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
441        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer.
442
443        **Args:**
444        - **input** (`Tensor`): The input tensor from data.
445        - **stage** (`str`): The stage of the forward pass; one of:
446            1. 'train': training stage.
447            2. 'validation': validation stage.
448            3. 'test': testing stage.
449        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
450        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
451
452        **Returns:**
453        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
454        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
455        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
456        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features for various purposes.
457        """
458        batch_size = input.size(0)
459        activations = {}
460
461        weight_mask, bias_mask = self.get_mask(
462            stage,
463            mask_percentage=mask_percentage,
464            test_mask=test_mask,
465        )
466
467        x = input.view(batch_size, -1)  # flatten before going through MLP
468
469        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
470            weighted_layer = self.fc[layer_idx]
471            weight = weighted_layer.weight
472            bias = weighted_layer.bias
473
474            # mask the weight and bias
475            masked_weight = weight * weight_mask[f"fc/{layer_idx}"]
476            if bias is not None and bias_mask[f"fc/{layer_idx}"] is not None:
477                masked_bias = bias * bias_mask[f"fc/{layer_idx}"]
478            else:
479                masked_bias = None
480
481            # do the forward pass using the masked weight and bias. Do not modify the weight and bias data in the original layer object or it will lose the computation graph.
482            x = nn.functional.linear(x, masked_weight, masked_bias)
483
484            if self.activation:
485                x = self.fc_activation[layer_idx](x)  # activation function third
486            activations[layer_name] = x  # store the hidden feature
487            if self.dropout:
488                x = self.fc_dropout[layer_idx](x)  # dropout last
489
490        output_feature = x
491
492        return output_feature, weight_mask, bias_mask, activations
493
494
495# r"""
496# The submodule in `backbones` for [NISPA (Neuro-Inspired Stability-Plasticity Adaptation)](https://proceedings.mlr.press/v162/gurbuz22a/gurbuz22a.pdf) masked MLP backbone network.
497# """
498
499# __all__ = ["NISPAMaskMLP"]
500
501# from torch import Tensor, nn
502
503# from clarena.backbones import MLP, NISPAMaskBackbone
504
505
506# class NISPAMaskMLP(MLP, NISPAMaskBackbone):
507#     r"""[NISPA (Neuro-Inspired Stability-Plasticity Adaptation)](https://proceedings.mlr.press/v162/gurbuz22a/gurbuz22a.pdf) masked multi-Layer perceptron (MLP).
508
509#     [NISPA (Neuro-Inspired Stability-Plasticity Adaptation)](https://proceedings.mlr.press/v162/gurbuz22a/gurbuz22a.pdf) is an architecture-based continual learning algorithm. It
510
511#     MLP is a dense network architecture, which has several fully-connected layers, each followed by an activation function. The last layer connects to the CL output heads.
512
513#     Mask is applied to the weights and biases in each fully-connected layer. The mask is generated from the parameter-wise score and gate function.
514#     """
515
516#     def __init__(
517#         self,
518#         input_dim: int,
519#         hidden_dims: list[int],
520#         output_dim: int,
521#         activation_layer: nn.Module | None = nn.ReLU,
522#         bias: bool = True,
523#         dropout: float | None = None,
524#     ) -> None:
525#         r"""Construct and initialize the WSN masked MLP backbone network with task embedding. Note that batch normalization is incompatible with WSN mechanism.
526
527#         **Args:**
528#         - **input_dim** (`int`): the input dimension. Any data need to be flattened before going in MLP.
529#         - **hidden_dims** (`list[int]`): list of hidden layer dimensions. It can be empty list which means single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension which we take as output dimension.
530#         - **output_dim** (`int`): the output dimension which connects to CL output heads.
531#         - **activation_layer** (`nn.Module` | `None`): activation function of each layer (if not `None`), if `None` this layer won't be used. Default `nn.ReLU`.
532#         - **bias** (`bool`): whether to use bias in the linear layer. Default `True`.
533#         - **dropout** (`float` | `None`): the probability for the dropout layer, if `None` this layer won't be used. Default `None`.
534#         """
535#         # init from both inherited classes
536#         super().__init__(
537#             input_dim=input_dim,
538#             hidden_dims=hidden_dims,
539#             output_dim=output_dim,
540#             activation_layer=activation_layer,
541#             batch_normalization=False,  # batch normalization is incompatible with the NISPA mechanism
542#             bias=bias,
543#             dropout=dropout,
544#         )
545#         self.register_wsn_mask_module_explicitly()  # register all `nn.Module`s for WSNMaskBackbone explicitly because the second `__init__()` wipes out them inited by the first `__init__()`
546
547#         # construct the parameter score for each weighted layer
548#         for layer_idx in range(self.num_fc_layers):
549#             full_layer_name = f"fc/{layer_idx}"
550#             layer_input_dim = (
551#                 input_dim if layer_idx == 0 else hidden_dims[layer_idx - 1]
552#             )
553#             layer_output_dim = (
554#                 hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
555#             )
556#             self.weight_score_t[full_layer_name] = nn.Embedding(
557#                 num_embeddings=layer_output_dim,
558#                 embedding_dim=layer_input_dim,
559#             )
560#             self.bias_score_t[full_layer_name] = nn.Embedding(
561#                 num_embeddings=1,
562#                 embedding_dim=layer_output_dim,
563#             )
564
565#     def forward(
566#         self,
567#         input: Tensor,
568#         stage: str,
569#         mask_percentage: float,
570#         test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
571#     ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
572#         r"""The forward pass for data from task `self.task_id`. Task-specific mask for `self.task_id` are applied to the units which are neurons in each fully-connected layer.
573
574#         **Args:**
575#         - **input** (`Tensor`): The input tensor from data.
576#         - **stage** (`str`): the stage of the forward pass; one of:
577#             1. 'train': training stage.
578#             2. 'validation': validation stage.
579#             3. 'test': testing stage.
580#         - **mask_percentage** (`float`): the percentage of parameters to be masked. The value should be between 0 and 1.
581#         - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): the binary weight and bias mask used for test. Applies only to testing stage. For other stages, it is default `None`.
582
583#         **Returns:**
584#         - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
585#         - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, input features) as weight.
586#         - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, ) as bias. If the layer doesn't have bias, it is `None`.
587#         - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
588#         """
589#         batch_size = input.size(0)
590#         activations = {}
591
592#         weight_mask, bias_mask = self.get_mask(
593#             stage,
594#             mask_percentage=mask_percentage,
595#             test_mask=test_mask,
596#         )
597
598#         x = input.view(batch_size, -1)  # flatten before going through MLP
599
600#         for layer_idx, layer_name in enumerate(self.weighted_layer_names):
601#             weighted_layer = self.fc[layer_idx]
602#             weight = weighted_layer.weight
603#             bias = weighted_layer.bias
604
605#             # mask the weight and bias
606#             masked_weight = weight * weight_mask[f"fc/{layer_idx}"]
607#             if bias is not None and bias_mask[f"fc/{layer_idx}"] is not None:
608#                 masked_bias = bias * bias_mask[f"fc/{layer_idx}"]
609#             else:
610#                 masked_bias = None
611
612#             # do the forward pass using the masked weight and bias. Do not modify the weight and bias data in the original layer object or it will lose the computation graph.
613#             x = nn.functional.linear(x, masked_weight, masked_bias)
614
615#             if self.activation:
616#                 x = self.fc_activation[layer_idx](x)  # activation function third
617#             activations[layer_name] = x  # store the hidden feature
618#             if self.dropout:
619#                 x = self.fc_dropout[layer_idx](x)  # dropout last
620
621#         output_feature = x
622
623#         return output_feature, weight_mask, bias_mask, activations
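
The listing above shows that `HATMaskMLP.forward()` multiplies each layer's output units by a task-specific mask generated from a neuron-wise task embedding and a gate function. The following standalone sketch illustrates that gating idea only; it is not the library's `get_mask()` implementation, and the layer size, batch size, and scaling factor are illustrative assumptions.

```python
import torch

# Neuron-wise task embedding for a layer with 128 units (illustrative size).
e = torch.randn(128)

# A large positive scaling factor (such as s_max in the HAT paper) pushes the
# sigmoid gate towards a near-binary, "hard" attention mask over the units.
s = 400.0
mask = torch.sigmoid(s * e)   # values close to 0 or 1, one per unit

# In HATMaskMLP.forward(), the mask multiplies the layer output element-wise.
x = torch.randn(32, 128)      # hidden features for a batch of 32
x_masked = x * mask           # units with mask ~0 are switched off for this task
```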
class MLP(clarena.backbones.base.Backbone):

Multi-layer perceptron (MLP), a.k.a. fully connected network.

MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the output heads.

MLP( input_dim: int, hidden_dims: list[int], output_dim: int, activation_layer: torch.nn.modules.module.Module | None = <class 'torch.nn.modules.activation.ReLU'>, batch_normalization: bool = False, bias: bool = True, dropout: float | None = None, **kwargs)

Construct and initialize the MLP backbone network.

Args:

  • input_dim (int): The input dimension. Any data needs to be flattened before entering the MLP. Note that it is not required in convolutional networks.
  • hidden_dims (list[int]): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
  • output_dim (int): The output dimension that connects to output heads.
  • activation_layer (nn.Module | None): Activation function of each layer (if not None). If None, this layer won't be used. Default nn.ReLU.
  • batch_normalization (bool): Whether to use batch normalization after the fully connected layers. Default False.
  • bias (bool): Whether to use bias in the linear layer. Default True.
  • dropout (float | None): The probability for the dropout layer. If None, this layer won't be used. Default None.
  • kwargs: Reserved for multiple inheritance.
input_dim: int

The input dimension of the MLP backbone network.

hidden_dims: list[int]

The hidden dimensions of the MLP backbone network.

output_dim: int

The output dimension of the MLP backbone network.

num_fc_layers: int

The number of fully-connected layers in the MLP backbone network, which helps form the loops in constructing layers and forward pass.

batch_normalization: bool

Whether to use batch normalization after the fully-connected layers.

activation: bool

Whether to use activation function after the fully-connected layers.

dropout: bool

Whether to use dropout after the fully-connected layers.

fc: torch.nn.modules.container.ModuleList

The list of fully connected (nn.Linear) layers.

def forward( self, input: torch.Tensor, stage: str = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:

The forward pass for data. It is the same for all tasks.

Args:

  • input (Tensor): The input tensor from data.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. This is used for certain algorithms that need to use hidden features for various purposes.
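
To make the relationship between `hidden_dims`, the weighted layer names, and the keys of the returned `activations` dict concrete, here is a small sketch with illustrative dimensions; it assumes `weighted_layer_names` behaves like a Python list, as the construction loop in the source suggests.

```python
from clarena.backbones.mlp import MLP

# An empty hidden_dims gives a single-layer MLP: one nn.Linear(input_dim, output_dim).
single_layer = MLP(input_dim=784, hidden_dims=[], output_dim=10)
print(single_layer.num_fc_layers)         # 1
print(single_layer.weighted_layer_names)  # ['fc/0']

# Two hidden dimensions give three weighted layers: 784 -> 256 -> 128 -> 64.
# These names are also the keys of the activations dict returned by forward().
three_layer = MLP(input_dim=784, hidden_dims=[256, 128], output_dim=64)
print(three_layer.weighted_layer_names)   # ['fc/0', 'fc/1', 'fc/2']
```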
class CLMLP(clarena.backbones.base.CLBackbone, MLP):

Multi-layer perceptron (MLP), a.k.a. fully connected network. Used as a continual learning backbone.

MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.

CLMLP( input_dim: int, hidden_dims: list[int], output_dim: int, activation_layer: torch.nn.modules.module.Module | None = <class 'torch.nn.modules.activation.ReLU'>, batch_normalization: bool = False, bias: bool = True, dropout: float | None = None, **kwargs)

Construct and initialize the CLMLP backbone network.

Args:

  • input_dim (int): the input dimension. Any data needs to be flattened before going into the MLP. Note that this is not required in convolutional networks.
  • hidden_dims (list[int]): list of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can have as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
  • output_dim (int): the output dimension, which connects to the CL output heads.
  • activation_layer (nn.Module | None): activation function of each layer (if not None); if None, this layer won't be used. Default nn.ReLU.
  • batch_normalization (bool): whether to use batch normalization after the fully-connected layers. Default False.
  • bias (bool): whether to use bias in the linear layer. Default True.
  • dropout (float | None): the probability for the dropout layer; if None, this layer won't be used. Default None.
  • kwargs: Reserved for multiple inheritance.
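
As a quick reference, here is a minimal construction sketch. The argument values are illustrative assumptions, not defaults prescribed by this page:

    from clarena.backbones.mlp import CLMLP

    # a small two-hidden-layer backbone; all values below are illustrative
    backbone = CLMLP(
        input_dim=784,            # e.g. flattened 28x28 images
        hidden_dims=[256, 128],   # two hidden fully-connected layers
        output_dim=64,            # feature dimension fed to the CL output heads
        batch_normalization=False,
        dropout=0.5,
    )
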
def forward( self, input: torch.Tensor, stage: str = None, task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
189    def forward(
190        self, input: Tensor, stage: str = None, task_id: int | None = None
191    ) -> tuple[Tensor, dict[str, Tensor]]:
192        r"""The forward pass for data. It is the same for all tasks.
193
194        **Args:**
195        - **input** (`Tensor`): The input tensor from data.
196
197        **Returns:**
198        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
199        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name; value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
200        """
201        return MLP.forward(self, input, stage)  # call the MLP forward method

The forward pass for data. It is the same for all tasks.

Args:

  • input (Tensor): The input tensor from data.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name; value (Tensor) is the hidden feature tensor. This is used for continual learning algorithms that need hidden features for various purposes.
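
A hedged usage sketch of this forward pass. The `stage="train"` value mirrors the stage names documented for the masked variants below, and the activation key `"fc/0"` assumes the layer-naming convention used elsewhere in this module:

    import torch

    from clarena.backbones.mlp import CLMLP

    backbone = CLMLP(input_dim=784, hidden_dims=[256, 128], output_dim=64)
    x = torch.randn(32, 784)  # a batch of 32 flattened inputs

    feature, activations = backbone(x, stage="train")
    # `feature` has shape (32, 64); `activations` maps layer names
    # (e.g. "fc/0") to the post-activation hidden tensors
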
class HATMaskMLP(clarena.backbones.base.HATMaskBackbone, MLP):
204class HATMaskMLP(HATMaskBackbone, MLP):
205    r"""HAT-masked multi-layer perceptron (MLP).
206
207    [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
208
209    MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.
210
211    The mask is applied to units (neurons) in each fully connected layer. The mask is generated from the neuron-wise task embedding and the gate function.
212    """
213
214    def __init__(
215        self,
216        input_dim: int,
217        hidden_dims: list[int],
218        output_dim: int,
219        gate: str,
220        activation_layer: nn.Module | None = nn.ReLU,
221        batch_normalization: str | None = None,
222        bias: bool = True,
223        dropout: float | None = None,
224    ) -> None:
225        r"""Construct and initialize the HAT-masked MLP backbone network with task embeddings. Note that batch normalization is incompatible with the HAT mechanism.
226
227        **Args:**
228        - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP.
229        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
230        - **output_dim** (`int`): The output dimension that connects to CL output heads.
231        - **gate** (`str`): The type of gate function turning real-valued task embeddings into attention masks; one of:
232            - `sigmoid`: the sigmoid function.
233        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
234        - **batch_normalization** (`str` | `None`): How to use batch normalization after the fully connected layers; one of:
235            - `None`: no batch normalization layers.
236            - `shared`: use a single batch normalization layer for all tasks. Note that this can cause catastrophic forgetting.
237            - `independent`: use independent batch normalization layers for each task.
238        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
239        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
240        """
241        super().__init__(
242            output_dim=output_dim,
243            gate=gate,
244            input_dim=input_dim,
245            hidden_dims=hidden_dims,
246            activation_layer=activation_layer,
247            batch_normalization=(
248                batch_normalization in ("shared", "independent")
249            ),
250            bias=bias,
251            dropout=dropout,
252        )
253
254        # construct the task embedding for each weighted layer
255        for layer_idx in range(self.num_fc_layers):
256            full_layer_name = f"fc/{layer_idx}"
257            layer_output_dim = (
258                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
259            )
260            self.task_embedding_t[full_layer_name] = nn.Embedding(
261                num_embeddings=1, embedding_dim=layer_output_dim
262            )
263
264        self.batch_normalization: str | None = batch_normalization
265        r"""The way to use batch normalization after the fully-connected layers. This overrides the `batch_normalization` argument in `MLP` class. """
266
267        # construct the batch normalization layers if needed
268        if self.batch_normalization == "independent":
269            self.fc_bns: nn.ModuleDict = nn.ModuleDict()  # initially empty
270            r"""Independent batch normalization layers are stored in a `ModuleDict`. Keys are task IDs and values are the corresponding batch normalization layers for the `nn.Linear`. We use `ModuleDict` rather than `dict` to ensure `LightningModule` can track these model parameters for purposes such as automatic device transfer and model summaries.
271            
272            Note that the task IDs must be string type in order to let `LightningModule` identify this part of the model."""
273            self.original_fc_bn_state_dict: dict = deepcopy(self.fc_bn.state_dict())
274            r"""The original batch normalization state dict as the source for creating new independent batch normalization layers. """
275
276    def setup_task_id(self, task_id: int) -> None:
277        r"""Set up task `task_id`. This must be done before the `forward()` method is called.
278
279        **Args:**
280        - **task_id** (`int`): The target task ID.
281        """
282        HATMaskBackbone.setup_task_id(self, task_id=task_id)
283
284        if self.batch_normalization == "independent":
285            if self.task_id not in self.fc_bns.keys():
286                self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn)
287
288    def get_bn(self, stage: str, test_task_id: int | None) -> nn.Module | None:
289        r"""Get the batch normalization layer used in the `forward()` method for different stages.
290
291        **Args:**
292        - **stage** (`str`): The stage of the forward pass; one of:
293            1. 'train': training stage.
294            2. 'validation': validation stage.
295            3. 'test': testing stage.
296        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
297
298        **Returns:**
299        - **fc_bn** (`nn.Module` | `None`): The batch normalization module.
300        """
301        if self.batch_normalization == "independent" and stage == "test":
302            return self.fc_bns[f"{test_task_id}"]
303        else:
304            return self.fc_bn
305
306    def initialize_independent_bn(self) -> None:
307        r"""Initialize the independent batch normalization layer for the current task. This is called when a new task is created. Applies only when `batch_normalization` is 'independent'."""
308
309        if self.batch_normalization == "independent":
310            self.fc_bn.load_state_dict(self.original_fc_bn_state_dict)
311
312    def store_bn(self) -> None:
313        r"""Store the batch normalization layer for the current task `self.task_id`. Applies only when `batch_normalization` is 'independent'."""
314
315        if self.batch_normalization == "independent":
316            self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn)
317
318    def forward(
319        self,
320        input: Tensor,
321        stage: str,
322        s_max: float | None = None,
323        batch_idx: int | None = None,
324        num_batches: int | None = None,
325        test_task_id: int | None = None,
326    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
327        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer.
328
329        **Args:**
330        - **input** (`Tensor`): The input tensor from data.
331        - **stage** (`str`): The stage of the forward pass; one of:
332            1. 'train': training stage.
333            2. 'validation': validation stage.
334            3. 'test': testing stage.
335        - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
336        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
337        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
338        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
339
340        **Returns:**
341        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
342        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
343        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency with other HAT-based algorithms that inherit this `forward()` method.
344        """
345        batch_size = input.size(0)
346        activations = {}
347
348        mask = self.get_mask(
349            stage=stage,
350            s_max=s_max,
351            batch_idx=batch_idx,
352            num_batches=num_batches,
353            test_task_id=test_task_id,
354        )
355        if self.batch_normalization:
356            fc_bn = self.get_bn(stage=stage, test_task_id=test_task_id)
357        x = input.view(batch_size, -1)  # flatten before going through MLP
358
359        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
360
361            x = self.fc[layer_idx](x)  # fully-connected layer first
362            if self.batch_normalization:
363                x = fc_bn[layer_idx](x)  # batch normalization second
364            x = x * mask[f"fc/{layer_idx}"]  # apply the task-specific unit mask
365
366            if self.activation:
367                x = self.fc_activation[layer_idx](x)  # activation function third
368            activations[layer_name] = x  # store the hidden feature
369            if self.dropout:
370                x = self.fc_dropout[layer_idx](x)  # dropout last
371
372        output_feature = x
373
374        return output_feature, mask, activations

HAT-masked multi-layer perceptron (MLP).

HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.

MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.

The mask is applied to units (neurons) in each fully connected layer. The mask is generated from the neuron-wise task embedding and the gate function.
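
For reference, the HAT paper (Sec. 2.4) generates the per-layer mask by pushing the scaled task embedding through the gate. The formula below is a sketch of that idea taken from the paper, not a copy of the `get_mask` implementation in `HATMaskBackbone`:

$$\mathbf{m}_t^l = \sigma\left(s \, \mathbf{e}_t^l\right), \qquad s = \frac{1}{s_{\max}} + \left(s_{\max} - \frac{1}{s_{\max}}\right) \frac{b - 1}{B - 1},$$

where $\mathbf{e}_t^l$ is the task embedding of layer $l$ for task $t$, $\sigma$ is the gate (here the sigmoid), $b$ is the current batch index and $B$ is the total number of batches. At test time $s = s_{\max}$, so the mask is close to binary.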

HATMaskMLP( input_dim: int, hidden_dims: list[int], output_dim: int, gate: str, activation_layer: torch.nn.modules.module.Module | None = <class 'torch.nn.modules.activation.ReLU'>, batch_normalization: str | None = None, bias: bool = True, dropout: float | None = None)
214    def __init__(
215        self,
216        input_dim: int,
217        hidden_dims: list[int],
218        output_dim: int,
219        gate: str,
220        activation_layer: nn.Module | None = nn.ReLU,
221        batch_normalization: str | None = None,
222        bias: bool = True,
223        dropout: float | None = None,
224    ) -> None:
225        r"""Construct and initialize the HAT-masked MLP backbone network with task embeddings. Note that batch normalization is incompatible with the HAT mechanism.
226
227        **Args:**
228        - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP.
229        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
230        - **output_dim** (`int`): The output dimension that connects to CL output heads.
231        - **gate** (`str`): The type of gate function turning real-valued task embeddings into attention masks; one of:
232            - `sigmoid`: the sigmoid function.
233        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
234        - **batch_normalization** (`str` | `None`): How to use batch normalization after the fully connected layers; one of:
235            - `None`: no batch normalization layers.
236            - `shared`: use a single batch normalization layer for all tasks. Note that this can cause catastrophic forgetting.
237            - `independent`: use independent batch normalization layers for each task.
238        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
239        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
240        """
241        super().__init__(
242            output_dim=output_dim,
243            gate=gate,
244            input_dim=input_dim,
245            hidden_dims=hidden_dims,
246            activation_layer=activation_layer,
247            batch_normalization=(
248                batch_normalization in ("shared", "independent")
249            ),
250            bias=bias,
251            dropout=dropout,
252        )
253
254        # construct the task embedding for each weighted layer
255        for layer_idx in range(self.num_fc_layers):
256            full_layer_name = f"fc/{layer_idx}"
257            layer_output_dim = (
258                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
259            )
260            self.task_embedding_t[full_layer_name] = nn.Embedding(
261                num_embeddings=1, embedding_dim=layer_output_dim
262            )
263
264        self.batch_normalization: str | None = batch_normalization
265        r"""The way to use batch normalization after the fully-connected layers. This overrides the `batch_normalization` argument in `MLP` class. """
266
267        # construct the batch normalization layers if needed
268        if self.batch_normalization == "independent":
269            self.fc_bns: nn.ModuleDict = nn.ModuleDict()  # initially empty
270            r"""Independent batch normalization layers are stored in a `ModuleDict`. Keys are task IDs and values are the corresponding batch normalization layers for the `nn.Linear`. We use `ModuleDict` rather than `dict` to ensure `LightningModule` can track these model parameters for purposes such as automatic device transfer and model summaries.
271            
272            Note that the task IDs must be string type in order to let `LightningModule` identify this part of the model."""
273            self.original_fc_bn_state_dict: dict = deepcopy(self.fc_bn.state_dict())
274            r"""The original batch normalization state dict as the source for creating new independent batch normalization layers. """

Construct and initialize the HAT-masked MLP backbone network with task embeddings. Note that a batch normalization layer shared across tasks does not interact cleanly with the HAT masking mechanism; see the batch_normalization argument for the available options.

Args:

  • input_dim (int): The input dimension. Any data need to be flattened before entering the MLP.
  • hidden_dims (list[int]): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
  • output_dim (int): The output dimension that connects to CL output heads.
  • gate (str): The type of gate function turning real-valued task embeddings into attention masks; one of:
    • sigmoid: the sigmoid function.
  • activation_layer (nn.Module | None): Activation function of each layer (if not None). If None, this layer won't be used. Default nn.ReLU.
  • batch_normalization (str | None): How to use batch normalization after the fully connected layers; one of:
    • None: no batch normalization layers.
    • shared: use a single batch normalization layer for all tasks. Note that this can cause catastrophic forgetting.
    • independent: use independent batch normalization layers for each task.
  • bias (bool): Whether to use bias in the linear layer. Default True.
  • dropout (float | None): The probability for the dropout layer. If None, this layer won't be used. Default None.
batch_normalization: str | None

The way to use batch normalization after the fully-connected layers. This overrides the batch_normalization attribute set in the MLP class.
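
A hedged construction sketch; the import path and argument values are illustrative assumptions:

    from clarena.backbones.mlp import HATMaskMLP

    backbone = HATMaskMLP(
        input_dim=784,
        hidden_dims=[256, 128],
        output_dim=64,
        gate="sigmoid",              # the only gate documented above
        batch_normalization=None,    # or "shared" / "independent"
    )
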

def setup_task_id(self, task_id: int) -> None:
276    def setup_task_id(self, task_id: int) -> None:
277        r"""Set up task `task_id`. This must be done before the `forward()` method is called.
278
279        **Args:**
280        - **task_id** (`int`): The target task ID.
281        """
282        HATMaskBackbone.setup_task_id(self, task_id=task_id)
283
284        if self.batch_normalization == "independent":
285            if self.task_id not in self.fc_bns.keys():
286                self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn)

Set up task task_id. This must be done before the forward() method is called.

Args:

  • task_id (int): The target task ID.
def get_bn( self, stage: str, test_task_id: int | None) -> torch.nn.modules.module.Module | None:
288    def get_bn(self, stage: str, test_task_id: int | None) -> nn.Module | None:
289        r"""Get the batch normalization layer used in the `forward()` method for different stages.
290
291        **Args:**
292        - **stage** (`str`): The stage of the forward pass; one of:
293            1. 'train': training stage.
294            2. 'validation': validation stage.
295            3. 'test': testing stage.
296        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
297
298        **Returns:**
299        - **fc_bn** (`nn.Module` | `None`): The batch normalization module.
300        """
301        if self.batch_normalization == "independent" and stage == "test":
302            return self.fc_bns[f"{test_task_id}"]
303        else:
304            return self.fc_bn

Get the batch normalization layer used in the forward() method for different stages.

Args:

  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • test_task_id (int | None): The test task ID. Applies only to the testing stage. For other stages, it is None.

Returns:

  • fc_bn (nn.Module | None): The batch normalization module.
def initialize_independent_bn(self) -> None:
306    def initialize_independent_bn(self) -> None:
307        r"""Initialize the independent batch normalization layer for the current task. This is called when a new task is created. Applies only when `batch_normalization` is 'independent'."""
308
309        if self.batch_normalization == "independent":
310            self.fc_bn.load_state_dict(self.original_fc_bn_state_dict)

Initialize the independent batch normalization layer for the current task. This is called when a new task is created. Applies only when batch_normalization is 'independent'.

def store_bn(self) -> None:
312    def store_bn(self) -> None:
313        r"""Store the batch normalization layer for the current task `self.task_id`. Applies only when `batch_normalization` is 'independent'."""
314
315        if self.batch_normalization == "independent":
316            self.fc_bns[f"{self.task_id}"] = deepcopy(self.fc_bn)

Store the batch normalization layer for the current task self.task_id. Applies only when batch_normalization is 'independent'.
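
Taken together, setup_task_id, initialize_independent_bn, store_bn and get_bn implement per-task bookkeeping for the 'independent' mode. The sketch below shows one plausible call order; the actual call sites live in the training loop of the HAT algorithm, outside this module, and may differ:

    backbone.setup_task_id(task_id=1)     # registers fc_bns["1"] if it doesn't exist yet
    backbone.initialize_independent_bn()  # reset the working fc_bn to its original state
    # ... train on task 1 ...
    backbone.store_bn()                   # snapshot the trained fc_bn into fc_bns["1"]
    # at test time, get_bn(stage="test", test_task_id=1) returns fc_bns["1"]
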

def forward( self, input: torch.Tensor, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor]]:
318    def forward(
319        self,
320        input: Tensor,
321        stage: str,
322        s_max: float | None = None,
323        batch_idx: int | None = None,
324        num_batches: int | None = None,
325        test_task_id: int | None = None,
326    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
327        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer.
328
329        **Args:**
330        - **input** (`Tensor`): The input tensor from data.
331        - **stage** (`str`): The stage of the forward pass; one of:
332            1. 'train': training stage.
333            2. 'validation': validation stage.
334            3. 'test': testing stage.
335        - **s_max** (`float`): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
336        - **batch_idx** (`int` | `None`): The current batch index. Applies only to the training stage. For other stages, it is `None`.
337        - **num_batches** (`int` | `None`): The total number of batches. Applies only to the training stage. For other stages, it is `None`.
338        - **test_task_id** (`int` | `None`): The test task ID. Applies only to the testing stage. For other stages, it is `None`.
339
340        **Returns:**
341        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
342        - **mask** (`dict[str, Tensor]`): The mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
343        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency with other HAT-based algorithms that inherit this `forward()` method.
344        """
345        batch_size = input.size(0)
346        activations = {}
347
348        mask = self.get_mask(
349            stage=stage,
350            s_max=s_max,
351            batch_idx=batch_idx,
352            num_batches=num_batches,
353            test_task_id=test_task_id,
354        )
355        if self.batch_normalization:
356            fc_bn = self.get_bn(stage=stage, test_task_id=test_task_id)
357        x = input.view(batch_size, -1)  # flatten before going through MLP
358
359        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
360
361            x = self.fc[layer_idx](x)  # fully-connected layer first
362            if self.batch_normalization:
363                x = fc_bn[layer_idx](x)  # batch normalization second
364            x = x * mask[f"fc/{layer_idx}"]  # apply the task-specific unit mask
365
366            if self.activation:
367                x = self.fc_activation[layer_idx](x)  # activation function third
368            activations[layer_name] = x  # store the hidden feature
369            if self.dropout:
370                x = self.fc_dropout[layer_idx](x)  # dropout last
371
372        output_feature = x
373
374        return output_feature, mask, activations

The forward pass for data from task self.task_id. Task-specific masks for self.task_id are applied to units (neurons) in each fully connected layer.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • s_max (float): The maximum scaling factor in the gate function. Doesn't apply to the testing stage. See Sec. 2.4 in the HAT paper.
  • batch_idx (int | None): The current batch index. Applies only to the training stage. For other stages, it is None.
  • num_batches (int | None): The total number of batches. Applies only to the training stage. For other stages, it is None.
  • test_task_id (int | None): The test task ID. Applies only to the testing stage. For other stages, it is None.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • mask (dict[str, Tensor]): The mask for the current task. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has size (number of units, ).
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features. Although the HAT algorithm does not need this, it is still provided for API consistency with other HAT-based algorithms that inherit this forward() method.
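
A hedged usage sketch of the masked forward pass. The value s_max=400 is the one reported in the HAT paper's experiments, the batch indices would normally be supplied by the training loop, and any further preconditions depend on `HATMaskBackbone`:

    import torch

    from clarena.backbones.mlp import HATMaskMLP

    backbone = HATMaskMLP(
        input_dim=784, hidden_dims=[256, 128], output_dim=64, gate="sigmoid"
    )
    backbone.setup_task_id(task_id=1)  # required before forward(), as documented above

    x = torch.randn(32, 784)
    feature, mask, activations = backbone(
        x, stage="train", s_max=400.0, batch_idx=0, num_batches=100
    )
    # mask["fc/0"] is a per-unit tensor of shape (256,); feature has shape (32, 64)
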
class WSNMaskMLP(MLP, clarena.backbones.base.WSNMaskBackbone):
377class WSNMaskMLP(MLP, WSNMaskBackbone):
378    r"""[WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) masked multi-layer perceptron (MLP).
379
380    [WSN (Winning Subnetworks, 2022)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) is an architecture-based continual learning algorithm. It trains learnable parameter-wise importance and selects the most important $c\%$ of the network parameters to be used for each task.
381
382    MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.
383
384    The mask is applied to the weights and biases in each fully connected layer. The mask is generated from the parameter-wise score and the gate function.
385    """
386
387    def __init__(
388        self,
389        input_dim: int,
390        hidden_dims: list[int],
391        output_dim: int,
392        activation_layer: nn.Module | None = nn.ReLU,
393        bias: bool = True,
394        dropout: float | None = None,
395    ) -> None:
396        r"""Construct and initialize the WSN-masked MLP backbone network with task embeddings.
397
398        **Args:**
399        - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP.
400        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
401        - **output_dim** (`int`): The output dimension that connects to CL output heads.
402        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
403        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
404        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
405        """
406        # init from both inherited classes
407        super().__init__(
408            input_dim=input_dim,
409            hidden_dims=hidden_dims,
410            output_dim=output_dim,
411            activation_layer=activation_layer,
412            batch_normalization=False,
413            bias=bias,
414            dropout=dropout,
415        )
416
417        # construct the parameter score for each weighted layer
418        for layer_idx in range(self.num_fc_layers):
419            full_layer_name = f"fc/{layer_idx}"
420            layer_input_dim = (
421                input_dim if layer_idx == 0 else hidden_dims[layer_idx - 1]
422            )
423            layer_output_dim = (
424                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
425            )
426            self.weight_score_t[full_layer_name] = nn.Embedding(
427                num_embeddings=layer_output_dim,
428                embedding_dim=layer_input_dim,
429            )
430            self.bias_score_t[full_layer_name] = nn.Embedding(
431                num_embeddings=1,
432                embedding_dim=layer_output_dim,
433            )
434
435    def forward(
436        self,
437        input: Tensor,
438        stage: str,
439        mask_percentage: float,
440        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
441    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
442        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer.
443
444        **Args:**
445        - **input** (`Tensor`): The input tensor from data.
446        - **stage** (`str`): The stage of the forward pass; one of:
447            1. 'train': training stage.
448            2. 'validation': validation stage.
449            3. 'test': testing stage.
450        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
451        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
452
453        **Returns:**
454        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
455        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
456        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
457        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features for various purposes.
458        """
459        batch_size = input.size(0)
460        activations = {}
461
462        weight_mask, bias_mask = self.get_mask(
463            stage,
464            mask_percentage=mask_percentage,
465            test_mask=test_mask,
466        )
467
468        x = input.view(batch_size, -1)  # flatten before going through MLP
469
470        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
471            weighted_layer = self.fc[layer_idx]
472            weight = weighted_layer.weight
473            bias = weighted_layer.bias
474
475            # mask the weight and bias
476            masked_weight = weight * weight_mask[f"fc/{layer_idx}"]
477            if bias is not None and bias_mask[f"fc/{layer_idx}"] is not None:
478                masked_bias = bias * bias_mask[f"fc/{layer_idx}"]
479            else:
480                masked_bias = None
481
482            # do the forward pass using the masked weight and bias. Do not modify the weight and bias data in the original layer object or it will lose the computation graph.
483            x = nn.functional.linear(x, masked_weight, masked_bias)
484
485            if self.activation:
486                x = self.fc_activation[layer_idx](x)  # activation function
487            activations[layer_name] = x  # store the hidden feature
488            if self.dropout:
489                x = self.fc_dropout[layer_idx](x)  # dropout last
490
491        output_feature = x
492
493        return output_feature, weight_mask, bias_mask, activations

WSN (Winning Subnetworks, 2022) masked multi-layer perceptron (MLP).

WSN (Winning Subnetworks, 2022) is an architecture-based continual learning algorithm. It trains learnable parameter-wise importance and selects the most important $c\%$ of the network parameters to be used for each task.

MLP is a dense network architecture with several fully connected layers, each followed by an activation function. The last layer connects to the CL output heads.

The mask is applied to the weights and biases in each fully connected layer. The mask is generated from the parameter-wise score and the gate function.
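
The selection rule behind the mask can be sketched as follows. topk_binary_mask is a hypothetical helper written purely for illustration (it is not part of clarena and does not reproduce the exact get_mask implementation); it keeps the parameters whose learnable scores fall in the top mask_percentage fraction:

    import torch

    def topk_binary_mask(score: torch.Tensor, mask_percentage: float) -> torch.Tensor:
        """Return a 0/1 mask keeping the top `mask_percentage` fraction of `score`."""
        k = max(1, int(mask_percentage * score.numel()))
        threshold = torch.topk(score.flatten(), k).values.min()
        return (score >= threshold).float()

    # e.g. a score tensor with the same shape as a 128x256 linear layer's weight
    weight_score = torch.randn(128, 256)
    weight_mask = topk_binary_mask(weight_score, mask_percentage=0.5)
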

WSNMaskMLP( input_dim: int, hidden_dims: list[int], output_dim: int, activation_layer: torch.nn.modules.module.Module | None = <class 'torch.nn.modules.activation.ReLU'>, bias: bool = True, dropout: float | None = None)
387    def __init__(
388        self,
389        input_dim: int,
390        hidden_dims: list[int],
391        output_dim: int,
392        activation_layer: nn.Module | None = nn.ReLU,
393        bias: bool = True,
394        dropout: float | None = None,
395    ) -> None:
396        r"""Construct and initialize the WSN-masked MLP backbone network with task embeddings.
397
398        **Args:**
399        - **input_dim** (`int`): The input dimension. Any data need to be flattened before entering the MLP.
400        - **hidden_dims** (`list[int]`): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
401        - **output_dim** (`int`): The output dimension that connects to CL output heads.
402        - **activation_layer** (`nn.Module` | `None`): Activation function of each layer (if not `None`). If `None`, this layer won't be used. Default `nn.ReLU`.
403        - **bias** (`bool`): Whether to use bias in the linear layer. Default `True`.
404        - **dropout** (`float` | `None`): The probability for the dropout layer. If `None`, this layer won't be used. Default `None`.
405        """
406        # init from both inherited classes
407        super().__init__(
408            input_dim=input_dim,
409            hidden_dims=hidden_dims,
410            output_dim=output_dim,
411            activation_layer=activation_layer,
412            batch_normalization=False,
413            bias=bias,
414            dropout=dropout,
415        )
416
417        # construct the parameter score for each weighted layer
418        for layer_idx in range(self.num_fc_layers):
419            full_layer_name = f"fc/{layer_idx}"
420            layer_input_dim = (
421                input_dim if layer_idx == 0 else hidden_dims[layer_idx - 1]
422            )
423            layer_output_dim = (
424                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
425            )
426            self.weight_score_t[full_layer_name] = nn.Embedding(
427                num_embeddings=layer_output_dim,
428                embedding_dim=layer_input_dim,
429            )
430            self.bias_score_t[full_layer_name] = nn.Embedding(
431                num_embeddings=1,
432                embedding_dim=layer_output_dim,
433            )

Construct and initialize the WSN-masked MLP backbone network with task embeddings.

Args:

  • input_dim (int): The input dimension. Any data need to be flattened before entering the MLP.
  • hidden_dims (list[int]): List of hidden layer dimensions. It can be an empty list, which means a single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension, which we take as the output dimension.
  • output_dim (int): The output dimension that connects to CL output heads.
  • activation_layer (nn.Module | None): Activation function of each layer (if not None). If None, this layer won't be used. Default nn.ReLU.
  • bias (bool): Whether to use bias in the linear layer. Default True.
  • dropout (float | None): The probability for the dropout layer. If None, this layer won't be used. Default None.
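
A hedged construction sketch; the import path and argument values are illustrative assumptions:

    from clarena.backbones.mlp import WSNMaskMLP

    backbone = WSNMaskMLP(
        input_dim=784,
        hidden_dims=[256, 128],
        output_dim=64,
        dropout=0.5,
    )
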
def forward( self, input: torch.Tensor, stage: str, mask_percentage: float, test_mask: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, torch.Tensor]]:
435    def forward(
436        self,
437        input: Tensor,
438        stage: str,
439        mask_percentage: float,
440        test_mask: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
441    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
442        r"""The forward pass for data from task `self.task_id`. Task-specific masks for `self.task_id` are applied to units (neurons) in each fully connected layer.
443
444        **Args:**
445        - **input** (`Tensor`): The input tensor from data.
446        - **stage** (`str`): The stage of the forward pass; one of:
447            1. 'train': training stage.
448            2. 'validation': validation stage.
449            3. 'test': testing stage.
450        - **mask_percentage** (`float`): The percentage of parameters to be masked. The value should be between 0 and 1.
451        - **test_mask** (`tuple[dict[str, Tensor], dict[str, Tensor]]` | `None`): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is `None`.
452
453        **Returns:**
454        - **output_feature** (`Tensor`): The output feature tensor to be passed into heads. This is the main target of backpropagation.
455        - **weight_mask** (`dict[str, Tensor]`): The weight mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
456        - **bias_mask** (`dict[str, Tensor]`): The bias mask for the current task. Keys (`str`) are the layer names and values (`Tensor`) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
457        - **activations** (`dict[str, Tensor]`): The hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names and values (`Tensor`) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features for various purposes.
458        """
459        batch_size = input.size(0)
460        activations = {}
461
462        weight_mask, bias_mask = self.get_mask(
463            stage,
464            mask_percentage=mask_percentage,
465            test_mask=test_mask,
466        )
467
468        x = input.view(batch_size, -1)  # flatten before going through MLP
469
470        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
471            weighted_layer = self.fc[layer_idx]
472            weight = weighted_layer.weight
473            bias = weighted_layer.bias
474
475            # mask the weight and bias
476            masked_weight = weight * weight_mask[f"fc/{layer_idx}"]
477            if bias is not None and bias_mask[f"fc/{layer_idx}"] is not None:
478                masked_bias = bias * bias_mask[f"fc/{layer_idx}"]
479            else:
480                masked_bias = None
481
482            # do the forward pass using the masked weight and bias. Do not modify the weight and bias data in the original layer object or it will lose the computation graph.
483            x = nn.functional.linear(x, masked_weight, masked_bias)
484
485            if self.activation:
486                x = self.fc_activation[layer_idx](x)  # activation function
487            activations[layer_name] = x  # store the hidden feature
488            if self.dropout:
489                x = self.fc_dropout[layer_idx](x)  # dropout last
490
491        output_feature = x
492
493        return output_feature, weight_mask, bias_mask, activations

The forward pass for data from task self.task_id. Task-specific masks for self.task_id are applied to the weights and biases in each fully connected layer.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): The stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • mask_percentage (float): The percentage of parameters to be masked. The value should be between 0 and 1.
  • test_mask (tuple[dict[str, Tensor], dict[str, Tensor]] | None): The binary weight and bias masks used for testing. Applies only to the testing stage. For other stages, it is None.

Returns:

  • output_feature (Tensor): The output feature tensor to be passed into heads. This is the main target of backpropagation.
  • weight_mask (dict[str, Tensor]): The weight mask for the current task. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, input features) as the weight.
  • bias_mask (dict[str, Tensor]): The bias mask for the current task. Keys (str) are the layer names and values (Tensor) are the mask tensors. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is None.
  • activations (dict[str, Tensor]): The hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. This is used for continual learning algorithms that need hidden features for various purposes.
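
A hedged usage sketch of the masked forward pass. Any per-task setup required by WSNMaskBackbone (for example a setup_task_id call) is assumed to be handled by the surrounding training loop and is not shown:

    import torch

    from clarena.backbones.mlp import WSNMaskMLP

    backbone = WSNMaskMLP(input_dim=784, hidden_dims=[256, 128], output_dim=64)

    x = torch.randn(32, 784)
    feature, weight_mask, bias_mask, activations = backbone(
        x, stage="train", mask_percentage=0.5
    )
    # weight_mask["fc/0"] has the same shape as the first linear layer's weight,
    # bias_mask["fc/0"] the same shape as its bias, and feature has shape (32, 64)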