clarena.backbones.mlp

The submodule in backbones for the MLP backbone networks.

  1r"""
  2The submodule in `backbones` for MLP backbone network.
  3"""
  4
  5__all__ = ["MLP", "HATMaskMLP"]
  6
  7from torch import Tensor, nn
  8
  9from clarena.backbones import CLBackbone, HATMaskBackbone
 10
 11
 12class MLP(CLBackbone):
 13    """**Multi-layer perceptron (MLP)** a.k.a. fully-connected network.
 14
 15    MLP is an dense network architecture, which has several fully-connected layers, each followed by an activation function. The last layer connects to the CL output heads.
 16    """
 17
 18    def __init__(
 19        self,
 20        input_dim: int,
 21        hidden_dims: list[int],
 22        output_dim: int,
 23        activation_layer: nn.Module | None = nn.ReLU,
 24        batch_normalisation: bool = False,
 25        bias: bool = True,
 26        dropout: float | None = None,
 27    ) -> None:
 28        r"""Construct and initialise the MLP backbone network.
 29
 30        **Args:**
 31        - **input_dim** (`int`): the input dimension. Any data need to be flattened before going in MLP. Note that it is not required in convolutional networks.
 32        - **hidden_dims** (`list[int]`): list of hidden layer dimensions. It can be empty list which means single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension which we take as output dimension.
 33        - **output_dim** (`int`): the output dimension which connects to CL output heads.
 34        - **activation_layer** (`nn.Module` | `None`): activation function of each layer (if not `None`), if `None` this layer won't be used. Default `nn.ReLU`.
 35        - **batch_normalisation** (`bool`): whether to use batch normalisation after the fully-connected layers. Default `False`.
 36        - **bias** (`bool`): whether to use bias in the linear layer. Default `True`.
 37        - **dropout** (`float` | `None`): the probability for the dropout layer, if `None` this layer won't be used. Default `None`.
 38        """
 39        CLBackbone.__init__(self, output_dim=output_dim)
 40
 41        self.num_fc_layers: int = len(hidden_dims) + 1
 42        r"""Store the number of fully-connected layers in the MLP backbone network, which helps form the loops in constructing layers and forward pass."""
 43        self.batch_normalisation: bool = batch_normalisation
 44        r"""Store whether to use batch normalisation after the fully-connected layers."""
 45        self.activation: bool = activation_layer is not None
 46        r"""Store whether to use activation function after the fully-connected layers."""
 47        self.dropout: bool = dropout is not None
 48        r"""Store whether to use dropout after the fully-connected layers."""
 49
 50        self.fc: nn.ModuleList = nn.ModuleList()
 51        r"""The list of fully-connected (`nn.Linear`) layers. """
 52        if self.batch_normalisation:
 53            self.fc_bn: nn.ModuleList = nn.ModuleList()
 54            r"""The list of batch normalisation (`nn.BatchNorm1d`) layers after the fully-connected layers."""
 55        if self.activation:
 56            self.fc_activation: nn.ModuleList = nn.ModuleList()
 57            r"""The list of activation layers after the fully-connected layers. """
 58        if self.dropout:
 59            self.fc_dropout: nn.ModuleList = nn.ModuleList()
 60            r"""The list of dropout layers after the fully-connected layers. """
 61
 62        # construct the weighted fully-connected layers and attached layers (batchnorm, activation, dropout, etc) in a loop
 63        for layer_idx in range(self.num_fc_layers):
 64
 65            # the input and output dim of the current weighted layer
 66            layer_input_dim = (
 67                input_dim if layer_idx == 0 else hidden_dims[layer_idx - 1]
 68            )
 69            layer_output_dim = (
 70                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
 71            )
 72
 73            # construct the fully connected layer
 74            self.fc.append(
 75                nn.Linear(
 76                    in_features=layer_input_dim,
 77                    out_features=layer_output_dim,
 78                    bias=bias,
 79                )
 80            )
 81
 82            # update the weighted layer names
 83            full_layer_name = f"fc/{layer_idx}"
 84            self.weighted_layer_names.append(full_layer_name)
 85
 86            # construct the batch normalisation layer
 87            if self.batch_normalisation:
 88                self.fc_bn.append(nn.BatchNorm1d(num_features=(layer_output_dim)))
 89
 90            # construct the activation layer
 91            if self.activation:
 92                self.fc_activation.append(activation_layer())
 93
 94            # construct the dropout layer
 95            if self.dropout:
 96                self.fc_dropout.append(nn.Dropout(dropout))
 97
 98    def forward(
 99        self, input: Tensor, stage: str = None, task_id: int | None = None
100    ) -> tuple[Tensor, dict[str, Tensor]]:
101        r"""The forward pass for data. It is the same for all tasks.
102
103        **Args:**
104        - **input** (`Tensor`): the input tensor from data.
105
106        **Returns:**
107        - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
108        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
109        """
110        batch_size = input.size(0)
111        hidden_features = {}
112
113        x = input.view(batch_size, -1)  # flatten before going through MLP
114
115        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
116            x = self.fc[layer_idx](x)  # fully-connected layer first
117            if self.batch_normalisation:
118                x = self.fc_bn[layer_idx](x)  # batch normalisation second
119            if self.activation:
120                x = self.fc_activation[layer_idx](x)  # activation function third
121            hidden_features[layer_name] = x  # store the hidden feature
122            if self.dropout:
123                x = self.fc_dropout[layer_idx](x)  # dropout last
124
125        output_feature = x
126
127        return output_feature, hidden_features
128
129
130class HATMaskMLP(MLP, HATMaskBackbone):
131    r"""HAT masked multi-Layer perceptron (MLP).
132
133    [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select the task-specific parameters.
134
135    MLP is a dense network architecture, which has several fully-connected layers, each followed by an activation function. The last layer connects to the CL output heads.
136
137    Mask is applied to the units which are neurons in each fully-connected layer. The mask is generated from the unit-wise task embedding and gate function.
138    """
139
140    def __init__(
141        self,
142        input_dim: int,
143        hidden_dims: list[int],
144        output_dim: int,
145        gate: str,
146        activation_layer: nn.Module | None = nn.ReLU,
147        bias: bool = True,
148        dropout: float | None = None,
149    ) -> None:
150        r"""Construct and initialise the HAT masked MLP backbone network with task embedding. Note that batch normalisation is incompatible with HAT mechanism.
151
152        **Args:**
153        - **input_dim** (`int`): the input dimension. Any data need to be flattened before going in MLP.
154        - **hidden_dims** (`list[int]`): list of hidden layer dimensions. It can be empty list which means single-layer MLP, and it can be as many layers as you want. Note that it doesn't include the last dimension which we take as output dimension.
155        - **output_dim** (`int`): the output dimension which connects to CL output heads.
156        - **gate** (`str`): the type of gate function turning the real value task embeddings into attention masks, should be one of the following:
157            - `sigmoid`: the sigmoid function.
158        - **activation_layer** (`nn.Module` | `None`): activation function of each layer (if not `None`), if `None` this layer won't be used. Default `nn.ReLU`.
159        - **bias** (`bool`): whether to use bias in the linear layer. Default `True`.
160        - **dropout** (`float` | `None`): the probability for the dropout layer, if `None` this layer won't be used. Default `None`.
161        """
162        # init from both inherited classes
163        HATMaskBackbone.__init__(self, output_dim=output_dim, gate=gate)
164        MLP.__init__(
165            self,
166            input_dim=input_dim,
167            hidden_dims=hidden_dims,
168            output_dim=output_dim,
169            activation_layer=activation_layer,
170            batch_normalisation=False,  # batch normalisation is incompatible with HAT mechanism
171            bias=bias,
172            dropout=dropout,
173        )
174        self.register_hat_mask_module_explicitly(
175            gate=gate
176        )  # register all `nn.Module`s for HATMaskBackbone explicitly because the second `__init__()` wipes out them inited by the first `__init__()`
177
178        # construct the task embedding for each weighted layer
179        for layer_idx in range(self.num_fc_layers):
180            full_layer_name = f"fc/{layer_idx}"
181            layer_output_dim = (
182                hidden_dims[layer_idx] if layer_idx != len(hidden_dims) else output_dim
183            )
184            self.task_embedding_t[full_layer_name] = nn.Embedding(
185                num_embeddings=1, embedding_dim=layer_output_dim
186            )
187
188    def forward(
189        self,
190        input: Tensor,
191        stage: str,
192        s_max: float | None = None,
193        batch_idx: int | None = None,
194        num_batches: int | None = None,
195        test_mask: dict[str, Tensor] | None = None,
196    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
197        r"""The forward pass for data from task `task_id`. Task-specific mask for `task_id` are applied to the units which are neurons in each fully-connected layer.
198
199        **Args:**
200        - **input** (`Tensor`): the input tensor from data.
201        - **stage** (`str`): the stage of the forward pass, should be one of the following:
202            1. 'train': training stage.
203            2. 'validation': validation stage.
204            3. 'test': testing stage.
205        - **s_max** (`float` | `None`): the maximum scaling factor in the gate function. Doesn't apply to testing stage. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
206        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
207        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
208        - **test_mask** (`dict[str, Tensor]` | `None`): the binary mask used for test. Applies only to testing stage. For other stages, it is default `None`.
209
210        **Returns:**
211        - **output_feature** (`Tensor`): the output feature tensor to be passed into heads. This is the main target of backpropagation.
212        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
213        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
214        """
215        batch_size = input.size(0)
216        hidden_features = {}
217
218        # get the mask for the current task from the task embedding in this stage
219        mask = self.get_mask(
220            stage=stage,
221            s_max=s_max,
222            batch_idx=batch_idx,
223            num_batches=num_batches,
224            test_mask=test_mask,
225        )
226
227        x = input.view(batch_size, -1)  # flatten before going through MLP
228
229        for layer_idx, layer_name in enumerate(self.weighted_layer_names):
230            x = self.fc[layer_idx](x)  # fully-connected layer first
231            x = x * mask[f"fc/{layer_idx}"]  # apply the mask to the parameters second
232            if self.activation:
233                x = self.fc_activation[layer_idx](x)  # activation function third
234            hidden_features[layer_name] = x  # store the hidden feature
235            if self.dropout:
236                x = self.fc_dropout[layer_idx](x)  # dropout last
237
238        output_feature = x
239
240        return output_feature, mask, hidden_features
class MLP(clarena.backbones.base.CLBackbone):

Multi-layer perceptron (MLP), a.k.a. fully-connected network.

MLP is a dense network architecture with several fully-connected layers, each followed by an activation function. The last layer connects to the CL output heads.
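
For orientation, the stack built by this class with its default settings is roughly equivalent to the following plain-PyTorch sketch (illustrative only; the dimensions are made up, and the real class additionally records the layer names and returns the per-layer hidden features):

import torch
from torch import nn

# Roughly what MLP(input_dim=784, hidden_dims=[256, 128], output_dim=64) builds
# with the defaults (ReLU activation, no batch normalisation, no dropout).
equivalent = nn.Sequential(
    nn.Flatten(),        # inputs are flattened before the first fully-connected layer
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),  # the last layer connects to the CL output heads
    nn.ReLU(),           # note: the activation is also applied after the last layer
)

features = equivalent(torch.randn(32, 1, 28, 28))  # -> shape (32, 64)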

MLP(input_dim: int, hidden_dims: list[int], output_dim: int, activation_layer: nn.Module | None = nn.ReLU, batch_normalisation: bool = False, bias: bool = True, dropout: float | None = None)

Construct and initialise the MLP backbone network.

Args:

  • input_dim (int): the input dimension. Any data needs to be flattened before going into the MLP. Note that this is not required for convolutional networks.
  • hidden_dims (list[int]): list of hidden layer dimensions. It can be an empty list (a single-layer MLP) or contain as many layers as you want. Note that it does not include the output dimension, which is given by output_dim.
  • output_dim (int): the output dimension, which connects to the CL output heads.
  • activation_layer (nn.Module | None): the activation function for each layer; if None, no activation layer is used. Default nn.ReLU.
  • batch_normalisation (bool): whether to use batch normalisation after the fully-connected layers. Default False.
  • bias (bool): whether to use bias in the linear layers. Default True.
  • dropout (float | None): the probability for the dropout layers; if None, no dropout layer is used. Default None.
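
A minimal construction sketch using these arguments (assuming the clarena package is installed; the dimensions below are hypothetical, chosen for flattened 28x28 inputs):

from clarena.backbones.mlp import MLP

backbone = MLP(
    input_dim=28 * 28,          # inputs are flattened to this dimension
    hidden_dims=[256, 128],     # two hidden fully-connected layers
    output_dim=64,              # feature dimension passed to the CL output heads
    batch_normalisation=True,   # optional BatchNorm1d after each linear layer
    dropout=0.5,                # optional dropout after each activation
)
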
num_fc_layers: int

Store the number of fully-connected layers in the MLP backbone network, which helps form the loops in constructing layers and forward pass.

batch_normalisation: bool

Store whether to use batch normalisation after the fully-connected layers.

activation: bool

Store whether to use an activation function after the fully-connected layers.

dropout: bool

Store whether to use dropout after the fully-connected layers.

fc: torch.nn.modules.container.ModuleList

The list of fully-connected (nn.Linear) layers.

def forward(self, input: torch.Tensor, stage: str | None = None, task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:

The forward pass for data. It is the same for all tasks.

Args:

  • input (Tensor): the input tensor from data.
  • stage (str | None): not used by this backbone; accepted for interface consistency with other backbones. Default None.
  • task_id (int | None): not used by this backbone; accepted for interface consistency with other backbones. Default None.

Returns:

  • output_feature (Tensor): the output feature tensor to be passed into the heads. This is the main target of backpropagation.
  • hidden_features (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
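
A short usage sketch of this forward pass (assuming the clarena package is installed; the input shape is hypothetical):

import torch
from clarena.backbones.mlp import MLP

backbone = MLP(input_dim=28 * 28, hidden_dims=[256, 128], output_dim=64)
x = torch.randn(32, 1, 28, 28)        # a batch of 1x28x28 inputs, flattened to 784 inside forward()

output_feature, hidden_features = backbone(x)
print(output_feature.shape)           # torch.Size([32, 64])
print(list(hidden_features.keys()))   # ['fc/0', 'fc/1', 'fc/2']
print(hidden_features["fc/0"].shape)  # torch.Size([32, 256])
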
class HATMaskMLP(MLP, clarena.backbones.base.HATMaskBackbone):

HAT masked multi-layer perceptron (MLP).

HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select the task-specific parameters.

MLP is a dense network architecture with several fully-connected layers, each followed by an activation function. The last layer connects to the CL output heads.

The mask is applied to the units, i.e. the neurons in each fully-connected layer. The mask is generated from the unit-wise task embedding and the gate function.
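
Conceptually, each layer's mask comes from that layer's task embedding pushed through a scaled gate (the actual computation lives in HATMaskBackbone.get_mask); a minimal sketch of the sigmoid case, with made-up sizes:

import torch

# One layer with 256 units: the real-valued task embedding has one value per unit.
task_embedding = torch.randn(256)
s = 400.0                                 # a large scaling factor pushes the gate towards 0/1
mask = torch.sigmoid(s * task_embedding)  # near-binary attention values, one per unit

layer_output = torch.randn(32, 256)       # output of the fully-connected layer
masked_output = layer_output * mask       # unit-wise gating, as done in forward()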

HATMaskMLP(input_dim: int, hidden_dims: list[int], output_dim: int, gate: str, activation_layer: nn.Module | None = nn.ReLU, bias: bool = True, dropout: float | None = None)

Construct and initialise the HAT masked MLP backbone network with task embeddings. Note that batch normalisation is incompatible with the HAT mechanism.

Args:

  • input_dim (int): the input dimension. Any data needs to be flattened before going into the MLP.
  • hidden_dims (list[int]): list of hidden layer dimensions. It can be an empty list (a single-layer MLP) or contain as many layers as you want. Note that it does not include the output dimension, which is given by output_dim.
  • output_dim (int): the output dimension, which connects to the CL output heads.
  • gate (str): the type of gate function that turns the real-valued task embeddings into attention masks, should be one of the following:
    • sigmoid: the sigmoid function.
  • activation_layer (nn.Module | None): the activation function for each layer; if None, no activation layer is used. Default nn.ReLU.
  • bias (bool): whether to use bias in the linear layers. Default True.
  • dropout (float | None): the probability for the dropout layers; if None, no dropout layer is used. Default None.
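
A minimal construction sketch (assuming the clarena package is installed; dimensions are hypothetical, and sigmoid is the only gate listed above):

from clarena.backbones.mlp import HATMaskMLP

backbone = HATMaskMLP(
    input_dim=28 * 28,
    hidden_dims=[256, 128],
    output_dim=64,
    gate="sigmoid",   # turns the real-valued task embeddings into attention masks
)
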
def forward( self, input: torch.Tensor, stage: str, s_max: float | None = None, batch_idx: int | None = None, num_batches: int | None = None, test_mask: dict[str, torch.Tensor] | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor]]:

The forward pass for data from task task_id. The task-specific mask for task_id is applied to the units, i.e. the neurons in each fully-connected layer.

Args:

  • input (Tensor): the input tensor from data.
  • stage (str): the stage of the forward pass, should be one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • s_max (float | None): the maximum scaling factor in the gate function. Does not apply to the testing stage. See Section 2.4 "Hard Attention Training" in the HAT paper.
  • batch_idx (int | None): the current batch index. Applies only to the training stage; for other stages, it defaults to None.
  • num_batches (int | None): the total number of batches. Applies only to the training stage; for other stages, it defaults to None.
  • test_mask (dict[str, Tensor] | None): the binary mask used at test time. Applies only to the testing stage; for other stages, it defaults to None.

Returns:

  • output_feature (Tensor): the output feature tensor to be passed into the heads. This is the main target of backpropagation.
  • mask (dict[str, Tensor]): the mask for the current task. Key (str) is the layer name, value (Tensor) is the mask tensor. The mask tensor has size (number of units).
  • hidden_features (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm itself does not need it, it is provided for API consistency with other HAT-based algorithms that inherit the forward() method of the HAT class.
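
A sketch of how a training-stage call is shaped (in a real run the HAT-based CL algorithm drives this, sets up the current task, and supplies these arguments; the dimensions and values are hypothetical):

import torch
from clarena.backbones.mlp import HATMaskMLP

backbone = HATMaskMLP(
    input_dim=28 * 28, hidden_dims=[256, 128], output_dim=64, gate="sigmoid"
)
x = torch.randn(32, 1, 28, 28)

# s_max, batch_idx and num_batches control the annealing of the gate's scaling
# factor during training; at test time a stored binary test_mask is passed instead.
output_feature, mask, hidden_features = backbone(
    x, stage="train", s_max=400.0, batch_idx=0, num_batches=100
)
# output_feature feeds the CL output heads; mask holds one tensor per layer,
# keyed by the weighted layer names 'fc/0', 'fc/1', 'fc/2'.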