clarena.cl_algorithms.hat

The submodule in cl_algorithms for the HAT (Hard Attention to the Task) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for the [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a).
  3"""
  4
  5__all__ = ["HAT"]
  6
  7import logging
  8from typing import Any
  9
 10import torch
 11from torch import Tensor
 12from torch.utils.data import DataLoader
 13
 14from clarena.backbones import HATMaskBackbone
 15from clarena.cl_algorithms import CLAlgorithm
 16from clarena.cl_algorithms.regularisers import HATMaskSparsityReg
 17from clarena.cl_heads import HeadsCIL, HeadsTIL
 18from clarena.utils import HATNetworkCapacity
 19
 20# always get logger for built-in logging in each module
 21pylogger = logging.getLogger(__name__)
 22
 23
 24class HAT(CLAlgorithm):
 25    r"""HAT (Hard Attention to the Task) algorithm.
 26
 27    [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select the task-specific parameters.
 28
 29    """
 30
 31    def __init__(
 32        self,
 33        backbone: HATMaskBackbone,
 34        heads: HeadsTIL | HeadsCIL,
 35        adjustment_mode: str,
 36        s_max: float,
 37        clamp_threshold: float,
 38        mask_sparsity_reg_factor: float,
 39        mask_sparsity_reg_mode: str = "original",
 40        task_embedding_init_mode: str = "N01",
 41        alpha: float | None = None,
 42    ) -> None:
 43        r"""Initialise the HAT algorithm with the network.
 44
 45        **Args:**
 46        - **backbone** (`HATMaskBackbone`): must be a backbone network with HAT mask mechanism.
 47        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
 48        - **adjustment_mode** (`str`): the strategy of adjustment i.e. the mode of gradient clipping, should be one of the following:
 49            1. 'hat': set the gradients of parameters linking to masked units to zero. This is what HAT does: it completely fixes the parts of the network allocated to previous tasks. See equation (2) in chapter 2.3 "Network Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 50            2. 'hat_random': set the gradients of parameters linking to masked units to random 0-1 values. See the "Baselines" section in chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 51            3. 'hat_const_alpha': set the gradients of parameters linking to masked units to a constant value of `alpha`. See the "Baselines" section in chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 52            4. 'hat_const_1': set the gradients of parameters linking to masked units to a constant value of 1, which means no gradient constraint on any parameter at all. See the "Baselines" section in chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 53        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 54        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 55        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularisation factor for mask sparsity.
 56        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularisation, should be one of the following:
 57            1. 'original' (default): the original mask sparsity regularisation in HAT paper.
 58            2. 'cross': the cross version mask sparsity regularisation.
 59        - **task_embedding_init_mode** (`str`): the initialisation mode for task embeddings, should be one of the following:
 60            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 61            2. 'U-11': uniform distribution $U(-1, 1)$.
 62            3. 'U01': uniform distribution $U(0, 1)$.
 63            4. 'U-10': uniform distribution $U(-1, 0)$.
 64            5. 'last': inherit task embedding from last task.
 65        - **alpha** (`float` | `None`): the `alpha` in the 'hat_const_alpha' mode. See the "Baselines" section in chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). It applies only when `adjustment_mode` is 'hat_const_alpha'.
 66        """
 67        CLAlgorithm.__init__(self, backbone=backbone, heads=heads)
 68
 69        self.adjustment_mode = adjustment_mode
 70        r"""Store the adjustment mode for gradient clipping."""
 71        self.s_max = s_max
 72        r"""Store s_max. """
 73        self.clamp_threshold = clamp_threshold
 74        r"""Store the clamp threshold for task embedding gradient compensation."""
 75        self.mask_sparsity_reg_factor = mask_sparsity_reg_factor
 76        r"""Store the mask sparsity regularisation factor."""
 77        self.mask_sparsity_reg_mode = mask_sparsity_reg_mode
 78        r"""Store the mask sparsity regularisation mode."""
 79        self.mark_sparsity_reg = HATMaskSparsityReg(
 80            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
 81        )
 82        r"""Initialise and store the mask sparsity regulariser."""
 83        self.task_embedding_init_mode = task_embedding_init_mode
 84        r"""Store the task embedding initialisation mode."""
 85        self.alpha = alpha if adjustment_mode == "hat_const_alpha" else None
 86        r"""Store the alpha for `hat_const_alpha`."""
 87        self.epsilon = None
 88        r"""HAT doesn't use the epsilon for `hat_const_alpha`. We still set it here to be consistent with the `epsilon` in `clip_grad_by_adjustment()` method in `HATMaskBackbone`."""
 89
 90        self.masks: dict[str, dict[str, Tensor]] = {}
 91        r"""Store the binary attention mask of each previous task gated from the task embedding. Keys are task IDs (string type) and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units). """
 92
 93        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
 94        r"""Store the cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embeddings. Keys are layer names and values are the cumulative binary mask tensor for the layer. The mask tensor has size (number of units). """
 95
 96        # set manual optimisation
 97        self.automatic_optimization = False
 98
 99        HAT.sanity_check(self)
100
101    def sanity_check(self) -> None:
102        r"""Check the sanity of the arguments.
103
104        **Raises:**
105        - **ValueError**: when backbone is not designed for HAT, or the `mask_sparsity_reg_mode` or `task_embedding_init_mode` is not one of the valid options. Also, if `alpha` is not given when `adjustment_mode` is 'hat_const_alpha'.
106        """
107        if not isinstance(self.backbone, HATMaskBackbone):
108            raise ValueError("The backbone should be an instance of HATMaskBackbone.")
109
110        if self.mask_sparsity_reg_mode not in ["original", "cross"]:
111            raise ValueError(
112                "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
113            )
114        if self.task_embedding_init_mode not in [
115            "N01",
116            "U01",
117            "U-10",
118            "masked",
119            "unmasked",
120        ]:
121            raise ValueError(
122                "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'."
123            )
124
125        if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
126            raise ValueError(
127                "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
128            )
129
130    def on_train_start(self) -> None:
131        r"""Initialise the task embedding before training the next task and initialise the cumulative mask at the beginning of the first task."""
132
133        self.backbone.initialise_task_embedding(mode=self.task_embedding_init_mode)
134
135        # initialise the cumulative mask at the beginning of first task. This should not be called in `__init__()` method as the `self.device` is not available at that time.
136        if self.task_id == 1:
137            for layer_name in self.backbone.weighted_layer_names:
138                layer = self.backbone.get_layer_by_name(
139                    layer_name
140                )  # get the layer by its name
141                num_units = layer.weight.shape[0]
142
143                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
144                    num_units
145                ).to(
146                    self.device
147                )  # the cumulative mask $\mathrm{M}^{<t}$ is initialised as zeros mask ($t = 1$). See equation (2) in chapter 3 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or equation (5) in chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
148
149    def clip_grad_by_adjustment(
150        self,
151        **kwargs,
152    ) -> Tensor:
153        r"""Clip the gradients by the adjustment rate.
154
155        Note that because the task embeddings cover every layer in the backbone network, no parameters are left out of this mechanism. This applies not only to the parameters between layers that have task embeddings, but also to those before the first layer, which are handled separately in the code.
156
157        Network capacity is measured alongside this clipping. It is defined as the average adjustment rate over all parameters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
158
159
160        **Returns:**
161        - **capacity** (`Tensor`): the calculated network capacity.
162        """
163
164        # initialise network capacity metric
165        capacity = HATNetworkCapacity()
166
167        # Calculate the adjustment rate for the gradients of the parameters, both weights and biases (if they exist)
168        for layer_name in self.backbone.weighted_layer_names:
169
170            layer = self.backbone.get_layer_by_name(
171                layer_name
172            )  # get the layer by its name
173
174            # placeholder for the adjustment rate to avoid the error of using it before assignment
175            adjustment_rate_weight = 1
176            adjustment_rate_bias = 1
177
178            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
179                unit_wise_measure=self.cumulative_mask_for_previous_tasks,
180                layer_name=layer_name,
181                aggregation="min",
182            )
183
184            if self.adjustment_mode == "hat":
185                adjustment_rate_weight = 1 - weight_mask
186                adjustment_rate_bias = 1 - bias_mask
187
188            elif self.adjustment_mode == "hat_random":
189                adjustment_rate_weight = torch.rand_like(weight_mask) * weight_mask + (
190                    1 - weight_mask
191                )
192                adjustment_rate_bias = torch.rand_like(bias_mask) * bias_mask + (
193                    1 - bias_mask
194                )
195
196            elif self.adjustment_mode == "hat_const_alpha":
197                adjustment_rate_weight = self.alpha * torch.ones_like(
198                    weight_mask
199                ) * weight_mask + (1 - weight_mask)
200                adjustment_rate_bias = self.alpha * torch.ones_like(
201                    bias_mask
202                ) * bias_mask + (1 - bias_mask)
203
204            elif self.adjustment_mode == "hat_const_1":
205                adjustment_rate_weight = torch.ones_like(weight_mask) * weight_mask + (
206                    1 - weight_mask
207                )
208                adjustment_rate_bias = torch.ones_like(bias_mask) * bias_mask + (
209                    1 - bias_mask
210                )
211
212            # apply the adjustment rate to the gradients
213            layer.weight.grad.data *= adjustment_rate_weight
214            if layer.bias is not None:
215                layer.bias.grad.data *= adjustment_rate_bias
216
217            # update network capacity metric
218            capacity.update(adjustment_rate_weight, adjustment_rate_bias)
219
220        return capacity.compute()
221
222    def compensate_task_embedding_gradients(
223        self,
224        batch_idx: int,
225        num_batches: int,
226    ) -> None:
227        r"""Compensate the gradients of task embeddings during training. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
228
229        **Args:**
230        - **batch_idx** (`int`): the current training batch index.
231        - **num_batches** (`int`): the total number of training batches.
232        """
233
234        for te in self.backbone.task_embedding_t.values():
235            anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
236                batch_idx - 1
237            ) / (
238                num_batches - 1
239            )  # see equation (3) in chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a)
240
241            num = (
242                torch.cosh(
243                    torch.clamp(
244                        anneal_scalar * te.weight.data,
245                        -self.clamp_threshold,
246                        self.clamp_threshold,
247                    )
248                )
249                + 1
250            )
251
252            den = torch.cosh(te.weight.data) + 1
253
254            compensation = self.s_max / anneal_scalar * num / den
255
256            te.weight.grad.data *= compensation
257
258    def forward(
259        self,
260        input: torch.Tensor,
261        stage: str,
262        batch_idx: int | None = None,
263        num_batches: int | None = None,
264        task_id: int | None = None,
265    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
266        r"""The forward pass for data from task `task_id`. Note that this has nothing to do with the `forward()` method in `nn.Module`.
267
268        **Args:**
269        - **input** (`Tensor`): The input tensor from data.
270        - **stage** (`str`): the stage of the forward pass, should be one of the following:
271            1. 'train': training stage.
272            2. 'validation': validation stage.
273            3. 'test': testing stage.
274        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
275        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
276        - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
277
278        **Returns:**
279        - **logits** (`Tensor`): the output logits tensor.
280        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
281        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm does not need it, it is still provided for API consistency with other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
282        """
283        feature, mask, hidden_features = self.backbone(
284            input,
285            stage=stage,
286            s_max=self.s_max if stage == "train" or stage == "validation" else None,
287            batch_idx=batch_idx if stage == "train" else None,
288            num_batches=num_batches if stage == "train" else None,
289            test_mask=self.masks[f"{task_id}"] if stage == "test" else None,
290        )
291        logits = self.heads(feature, task_id)
292
293        return logits, mask, hidden_features
294
295    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
296        r"""Training step for current task `self.task_id`.
297
298        **Args:**
299        - **batch** (`Any`): a batch of training data.
300        - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
301
302        **Returns:**
303        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Key (`str`) is the metric name, value (`Tensor`) is the metric. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. For HAT, it also includes 'mask' and 'capacity' for logging.
304        """
305        x, y = batch
306
307        # zero the gradients before forward pass in manual optimisation mode
308        opt = self.optimizers()
309        opt.zero_grad()
310
311        # classification loss
312        num_batches = self.trainer.num_training_batches
313        logits, mask, hidden_features = self.forward(
314            x,
315            stage="train",
316            batch_idx=batch_idx,
317            num_batches=num_batches,
318            task_id=self.task_id,
319        )
320        loss_cls = self.criterion(logits, y)
321
322        # regularisation loss. See chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
323        loss_reg, network_sparsity = self.mark_sparsity_reg(
324            mask, self.cumulative_mask_for_previous_tasks
325        )
326
327        # total loss
328        loss = loss_cls + loss_reg
329
330        # backward step (manually)
331        self.manual_backward(loss)  # calculate the gradients
332        # HAT hard-clips gradients by the cumulative masks. See equation (2) in chapter 2.3 "Network Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a). Network capacity is calculated along with this process; it is defined as the average adjustment rate over all parameters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
333        capacity = self.clip_grad_by_adjustment(
334            network_sparsity=network_sparsity,  # pass a keyword argument network sparsity here to make it compatible with AdaHAT. AdaHAT inherits this `training_step()` method.
335        )
336        # compensate the gradients of task embedding. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
337        self.compensate_task_embedding_gradients(
338            batch_idx=batch_idx,
339            num_batches=num_batches,
340        )
341        # update parameters with the modified gradients
342        opt.step()
343
344        # accuracy of the batch
345        acc = (logits.argmax(dim=1) == y).float().mean()
346
347        return {
348            "loss": loss,  # Returning the loss is essential for the training step, otherwise backpropagation will fail
349            "loss_cls": loss_cls,
350            "loss_reg": loss_reg,
351            "acc": acc,
352            "hidden_features": hidden_features,
353            "mask": mask,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
354            "capacity": capacity,
355        }
356
357    def on_train_end(self) -> None:
358        r"""Store the mask and update cumulative mask after training the task."""
359
360        # store the mask for the current task
361        mask_t = {
362            layer_name: self.backbone.gate_fn(
363                self.backbone.task_embedding_t[layer_name].weight * self.s_max
364            )
365            .squeeze()
366            .detach()
367            for layer_name in self.backbone.weighted_layer_names
368        }
369
370        self.masks[f"{self.task_id}"] = mask_t
371
372        # update the cumulative mask with the mask of the current task
373        self.cumulative_mask_for_previous_tasks = {
374            layer_name: torch.max(
375                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
376            )
377            for layer_name in self.backbone.weighted_layer_names
378        }
379
380    def validation_step(self, batch: Any) -> dict[str, Tensor]:
381        r"""Validation step for current task `self.task_id`.
382
383        **Args:**
384        - **batch** (`Any`): a batch of validation data.
385
386        **Returns:**
387        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Key (`str`) is the metric name, value (`Tensor`) is the metric.
388        """
389        x, y = batch
390        logits, mask, hidden_features = self.forward(
391            x, stage="validation", task_id=self.task_id
392        )
393        loss_cls = self.criterion(logits, y)
394        acc = (logits.argmax(dim=1) == y).float().mean()
395
396        return {
397            "loss_cls": loss_cls,
398            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
399        }
400
401    def test_step(
402        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
403    ) -> dict[str, Tensor]:
404        r"""Test step for the current task `self.task_id`, which tests data from all seen tasks indexed by `dataloader_idx`.
405
406        **Args:**
407        - **batch** (`Any`): a batch of test data.
408        - **dataloader_idx** (`int`): the dataloader index, which identifies the seen task to be tested (`task_id = dataloader_idx + 1`). A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.
409
410        **Returns:**
411        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Key (`str`) is the metric name, value (`Tensor`) is the metric.
412        """
413        test_task_id = dataloader_idx + 1
414
415        x, y = batch
416        logits, mask, hidden_features = self.forward(
417            x,
418            stage="test",
419            task_id=test_task_id,
420        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
421        loss_cls = self.criterion(logits, y)
422        acc = (logits.argmax(dim=1) == y).float().mean()
423
424        return {
425            "loss_cls": loss_cls,
426            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
427        }
class HAT(clarena.cl_algorithms.base.CLAlgorithm):

HAT (Hard Attention to the Task) algorithm.

HAT (Hard Attention to the Task, 2018) is an architecture-based continual learning approach that uses learnable hard attention masks to select the task-specific parameters.
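
The HAT paper gates each layer's attention with a scaled sigmoid $\sigma$; in this module the gate is whatever backbone.gate_fn implements. As a compact reference for the mechanism, the mask of task $t$ in layer $l$ is gated from a learnable task embedding $\mathrm{e}^t_l$ with a scaling factor $s$ that is annealed over the training batches of each epoch:

$$\mathrm{a}^t_l = \sigma\!\left(s\,\mathrm{e}^t_l\right), \qquad s = \frac{1}{s_{\max}} + \left(s_{\max} - \frac{1}{s_{\max}}\right)\frac{b-1}{B-1},$$

where $b$ is the batch index and $B$ is the number of training batches, matching chapter 2.4 of the HAT paper and the formula used in compensate_task_embedding_gradients(). At the end of each task, on_train_end() evaluates the gate at $s = s_{\max}$ to obtain the (near-)binary mask that is stored in self.masks and accumulated for later tasks.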

HAT( backbone: clarena.backbones.HATMaskBackbone, heads: clarena.cl_heads.HeadsTIL | clarena.cl_heads.HeadsCIL, adjustment_mode: str, s_max: float, clamp_threshold: float, mask_sparsity_reg_factor: float, mask_sparsity_reg_mode: str = 'original', task_embedding_init_mode: str = 'N01', alpha: float | None = None)

Initialise the HAT algorithm with the network.

Args:

  • backbone (HATMaskBackbone): must be a backbone network with HAT mask mechanism.
  • heads (HeadsTIL | HeadsCIL): output heads.
  • adjustment_mode (str): the strategy of adjustment i.e. the mode of gradient clipping, should be one of the following:
    1. 'hat': set the gradients of parameters linking to masked units to zero. This is what HAT does: it completely fixes the parts of the network allocated to previous tasks (a standalone sketch of this rule is given after this argument list). See equation (2) in chapter 2.3 "Network Training" in HAT paper.
    2. 'hat_random': set the gradients of parameters linking to masked units to random 0-1 values. See the "Baselines" section in chapter 4.1 in AdaHAT paper.
    3. 'hat_const_alpha': set the gradients of parameters linking to masked units to a constant value of alpha. See the "Baselines" section in chapter 4.1 in AdaHAT paper.
    4. 'hat_const_1': set the gradients of parameters linking to masked units to a constant value of 1, which means no gradient constraint on any parameter at all. See the "Baselines" section in chapter 4.1 in AdaHAT paper.
  • s_max (float): hyperparameter, the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in HAT paper.
  • clamp_threshold (float): the threshold for task embedding gradient compensation. See chapter 2.5 "Embedding Gradient Compensation" in HAT paper.
  • mask_sparsity_reg_factor (float): hyperparameter, the regularisation factor for mask sparsity.
  • mask_sparsity_reg_mode (str): the mode of mask sparsity regularisation, should be one of the following:
    1. 'original' (default): the original mask sparsity regularisation in HAT paper.
    2. 'cross': the cross version mask sparsity regularisation.
  • task_embedding_init_mode (str): the initialisation mode for task embeddings, should be one of the following:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit task embedding from last task.
  • alpha (float | None): the alpha in the 'hat_const_alpha' mode. See the "Baselines" section in chapter 4.1 in AdaHAT paper. It applies only when adjustment_mode is 'hat_const_alpha'.
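
To make the 'hat' rule in option 1 above concrete, here is a small standalone sketch using plain torch (not the clarena API). The made-up masks and the parameter-wise expansion are illustrative assumptions: the weight mask is taken as the element-wise minimum of the output-unit and input-unit cumulative masks, as in equation (2) of the HAT paper and as clip_grad_by_adjustment() requests from the backbone via aggregation="min".

```python
import torch

# Illustration of the 'hat' adjustment rule on one linear layer. The unit-wise
# masks below are made up; in HAT they come from the cumulative mask of previous tasks.
layer = torch.nn.Linear(4, 3)
layer(torch.randn(8, 4)).sum().backward()          # populate layer.weight.grad / layer.bias.grad

mask_out = torch.tensor([1.0, 0.0, 1.0])           # cumulative mask of this layer's units
mask_in = torch.tensor([0.0, 1.0, 1.0, 0.0])       # cumulative mask of the previous layer's units

# parameter-wise mask with 'min' aggregation: a weight counts as used by previous
# tasks only if both the unit it feeds and the unit feeding it are masked
weight_mask = torch.min(mask_out.unsqueeze(1), mask_in.unsqueeze(0))   # shape (3, 4)
bias_mask = mask_out                               # bias assumed to follow the output-unit mask

# 'hat' mode: zero the gradients of parameters linking to masked units
layer.weight.grad *= 1 - weight_mask
layer.bias.grad *= 1 - bias_mask
```

Under this rule, a weight that connects two units used by previous tasks receives zero gradient, while any weight with a free input or output unit keeps its full gradient.
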
adjustment_mode

Store the adjustment mode for gradient clipping.

s_max

Store s_max.

clamp_threshold

Store the clamp threshold for task embedding gradient compensation.

mask_sparsity_reg_factor

Store the mask sparsity regularisation factor.

mask_sparsity_reg_mode

Store the mask sparsity regularisation mode.

mark_sparsity_reg

Initialise and store the mask sparsity regulariser.

task_embedding_init_mode

Store the task embedding initialisation mode.

alpha

Store the alpha for hat_const_alpha.

epsilon

HAT doesn't use the epsilon for hat_const_alpha. We still set it here to be consistent with the epsilon in clip_grad_by_adjustment() method in HATMaskBackbone.

masks: dict[str, dict[str, torch.Tensor]]

Store the binary attention mask of each previous task gated from the task embedding. Keys are task IDs (string type) and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units).

cumulative_mask_for_previous_tasks: dict[str, torch.Tensor]

Store the cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embeddings. Keys are layer names and values are the cumulative binary mask tensor for the layer. The mask tensor has size (number of units).

automatic_optimization: bool

If set to False you are responsible for calling .backward(), .step(), .zero_grad().
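
HAT turns automatic optimisation off because training_step() edits the gradients in place between the backward pass and the optimiser step. Schematically, with plain torch stand-ins (the in-place scaling below stands in for clip_grad_by_adjustment() and the task-embedding gradient compensation):

```python
import torch

model = torch.nn.Linear(4, 3)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

opt.zero_grad()                          # zero the gradients before the forward pass
loss = model(torch.randn(8, 4)).sum()    # stands in for loss_cls + loss_reg
loss.backward()                          # in the LightningModule: self.manual_backward(loss)
for p in model.parameters():
    p.grad *= 0.5                        # stands in for the gradient adjustment and compensation
opt.step()                               # update the parameters with the modified gradients
```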

def sanity_check(self) -> None:

Check the sanity of the arguments.

Raises:

  • ValueError: when the backbone is not an instance of HATMaskBackbone, when mask_sparsity_reg_mode or task_embedding_init_mode is not one of the valid options, or when alpha is not given while adjustment_mode is 'hat_const_alpha'.

def on_train_start(self) -> None:
131    def on_train_start(self) -> None:
132        r"""Initialise the task embedding before training the next task and initialise the cumulative mask at the beginning of first task."""
133
134        self.backbone.initialise_task_embedding(mode=self.task_embedding_init_mode)
135
136        # initialise the cumulative mask at the beginning of first task. This should not be called in `__init__()` method as the `self.device` is not available at that time.
137        if self.task_id == 1:
138            for layer_name in self.backbone.weighted_layer_names:
139                layer = self.backbone.get_layer_by_name(
140                    layer_name
141                )  # get the layer by its name
142                num_units = layer.weight.shape[0]
143
144                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
145                    num_units
146                ).to(
147                    self.device
148                )  # the cumulative mask $\mathrm{M}^{<t}$ is initialised as zeros mask ($t = 1$). See equation (2) in chapter 3 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or equation (5) in chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).

Initialise the task embedding before training the next task, and initialise the cumulative mask at the beginning of the first task.

def clip_grad_by_adjustment(self, **kwargs) -> torch.Tensor:
150    def clip_grad_by_adjustment(
151        self,
152        **kwargs,
153    ) -> Tensor:
154        r"""Clip the gradients by the adjustment rate.
155
156        Note that as the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only the parameters in between layers with task embedding, but also those before the first layer. We designed it seperately in the codes.
157
158        Network capacity is measured along with this method. Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
159
160
161        **Returns:**
162        - **capacity** (`Tensor`): the calculated network capacity.
163        """
164
165        # initialise network capacity metric
166        capacity = HATNetworkCapacity()
167
168        # Calculate the adjustment rate for gradients of the parameters, both weights and biases (if exists)
169        for layer_name in self.backbone.weighted_layer_names:
170
171            layer = self.backbone.get_layer_by_name(
172                layer_name
173            )  # get the layer by its name
174
175            # placeholder for the adjustment rate to avoid the error of using it before assignment
176            adjustment_rate_weight = 1
177            adjustment_rate_bias = 1
178
179            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
180                unit_wise_measure=self.cumulative_mask_for_previous_tasks,
181                layer_name=layer_name,
182                aggregation="min",
183            )
184
185            if self.adjustment_mode == "hat":
186                adjustment_rate_weight = 1 - weight_mask
187                adjustment_rate_bias = 1 - bias_mask
188
189            elif self.adjustment_mode == "hat_random":
190                adjustment_rate_weight = torch.rand_like(weight_mask) * weight_mask + (
191                    1 - weight_mask
192                )
193                adjustment_rate_bias = torch.rand_like(bias_mask) * bias_mask + (
194                    1 - bias_mask
195                )
196
197            elif self.adjustment_mode == "hat_const_alpha":
198                adjustment_rate_weight = self.alpha * torch.ones_like(
199                    weight_mask
200                ) * weight_mask + (1 - weight_mask)
201                adjustment_rate_bias = self.alpha * torch.ones_like(
202                    bias_mask
203                ) * bias_mask + (1 - bias_mask)
204
205            elif self.adjustment_mode == "hat_const_1":
206                adjustment_rate_weight = torch.ones_like(weight_mask) * weight_mask + (
207                    1 - weight_mask
208                )
209                adjustment_rate_bias = torch.ones_like(bias_mask) * bias_mask + (
210                    1 - bias_mask
211                )
212
213            # apply the adjustment rate to the gradients
214            layer.weight.grad.data *= adjustment_rate_weight
215            if layer.bias is not None:
216                layer.bias.grad.data *= adjustment_rate_bias
217
218            # update network capacity metric
219            capacity.update(adjustment_rate_weight, adjustment_rate_bias)
220
221        return capacity.compute()

Clip the gradients by the adjustment rate.

Note that as the task embedding fully covers every layer in the backbone network, no parameters are left out of this mechanism. This applies not only to the parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

Network capacity is measured alongside this method. It is defined as the average adjustment rate over all parameters. See chapter 4.1 in AdaHAT paper.

Returns:

  • capacity (Tensor): the calculated network capacity.
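
As a worked illustration (not part of the class), the four adjustment modes map a toy cumulative weight mask to the gradient scaling factors below; the mask values and alpha are hypothetical.

    import torch

    # Toy cumulative weight mask: 1 = parameter links to a unit used by previous
    # tasks (protected), 0 = free parameter.
    weight_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
    alpha = 0.1  # hypothetical value for the 'hat_const_alpha' mode

    adjustment_rates = {
        "hat": 1 - weight_mask,                                      # [0., 0., 1., 1.]
        "hat_random": torch.rand_like(weight_mask) * weight_mask + (1 - weight_mask),
        "hat_const_alpha": alpha * weight_mask + (1 - weight_mask),  # [0.1, 0.1, 1., 1.]
        "hat_const_1": torch.ones_like(weight_mask),                 # [1., 1., 1., 1.]
    }

    # The gradients are multiplied in place by the chosen rate, e.g. for 'hat':
    # layer.weight.grad.data *= adjustment_rates["hat"]
    # The network capacity reported by HATNetworkCapacity is the average of these
    # rates over all parameters, e.g. adjustment_rates["hat"].mean() == 0.5 here.
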
def compensate_task_embedding_gradients(self, batch_idx: int, num_batches: int) -> None:
223    def compensate_task_embedding_gradients(
224        self,
225        batch_idx: int,
226        num_batches: int,
227    ) -> None:
228        r"""Compensate the gradients of task embeddings during training. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
229
230        **Args:**
231        - **batch_idx** (`int`): the current training batch index.
232        - **num_batches** (`int`): the total number of training batches.
233        """
234
235        for te in self.backbone.task_embedding_t.values():
236            anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
237                batch_idx - 1
238            ) / (
239                num_batches - 1
240            )  # see equation (3) in chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a)
241
242            num = (
243                torch.cosh(
244                    torch.clamp(
245                        anneal_scalar * te.weight.data,
246                        -self.clamp_threshold,
247                        self.clamp_threshold,
248                    )
249                )
250                + 1
251            )
252
253            den = torch.cosh(te.weight.data) + 1
254
255            compensation = self.s_max / anneal_scalar * num / den
256
257            te.weight.grad.data *= compensation

Compensate the gradients of task embeddings during training. See chapter 2.5 "Embedding Gradient Compensation" in HAT paper.

Args:

  • batch_idx (int): the current training batch index.
  • num_batches (int): the total number of training batches.
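
For a numerical feel of the compensation, the sketch below recomputes the annealed scalar and the compensation factor outside the class; the embedding values, batch index and hyperparameters are illustrative only.

    import torch

    # Illustrative values: a mid-epoch batch with HAT-paper-scale hyperparameters.
    s_max, clamp_threshold = 400.0, 50.0
    batch_idx, num_batches = 51, 101
    te_weight = torch.tensor([-2.0, 0.0, 3.0])  # a toy task-embedding row

    # Annealed scalar, growing from 1/s_max to s_max over the epoch
    # (equation (3) in chapter 2.4 of HAT paper).
    anneal_scalar = 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)

    # Multiplicative compensation factor for the task-embedding gradients.
    num = torch.cosh(
        torch.clamp(anneal_scalar * te_weight, -clamp_threshold, clamp_threshold)
    ) + 1
    den = torch.cosh(te_weight) + 1
    compensation = s_max / anneal_scalar * num / den
    # te.weight.grad.data *= compensation  (as done in the method above)
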
def forward( self, input: torch.Tensor, stage: str, batch_idx: int | None = None, num_batches: int | None = None, task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
259    def forward(
260        self,
261        input: torch.Tensor,
262        stage: str,
263        batch_idx: int | None = None,
264        num_batches: int | None = None,
265        task_id: int | None = None,
266    ) -> tuple[Tensor, dict[str, Tensor]]:
267        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
268
269        **Args:**
270        - **input** (`Tensor`): The input tensor from data.
271        - **stage** (`str`): the stage of the forward pass, should be one of the following:
272            1. 'train': training stage.
273            2. 'validation': validation stage.
274            3. 'test': testing stage.
275        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
276        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
277        - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
278
279        **Returns:**
280        - **logits** (`Tensor`): the output logits tensor.
281        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
282        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
283        """
284        feature, mask, hidden_features = self.backbone(
285            input,
286            stage=stage,
287            s_max=self.s_max if stage == "train" or stage == "validation" else None,
288            batch_idx=batch_idx if stage == "train" else None,
289            num_batches=num_batches if stage == "train" else None,
290            test_mask=self.masks[f"{task_id}"] if stage == "test" else None,
291        )
292        logits = self.heads(feature, task_id)
293
294        return logits, mask, hidden_features

The forward pass for data from task task_id. Note that this has nothing to do with the forward() method in nn.Module.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): the stage of the forward pass, should be one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • batch_idx (int | None): the current batch index. Applies only to the training stage; for other stages it defaults to None.
  • num_batches (int | None): the total number of batches. Applies only to the training stage; for other stages it defaults to None.
  • task_id (int | None): the task ID the data come from. If the stage is 'train' or 'validation', it should be the current task self.task_id. If the stage is 'test', it can be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. The HAT algorithm works only for TIL.

Returns:

  • logits (Tensor): the output logits tensor.
  • mask (dict[str, Tensor]): the mask for the current task. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has size (number of units).
  • hidden_features (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm does not need it, it is provided for API consistency with other HAT-based algorithms that inherit this forward() method of the HAT class.
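
A hedged usage sketch follows. It assumes hat_algorithm is a HAT instance that has already trained task 1, x is an input batch tensor, and batch_idx and num_batches are known from the training loop.

    # Training / validation: the mask is gated from the current task embedding with s_max.
    logits, mask, hidden_features = hat_algorithm.forward(
        x,
        stage="train",
        batch_idx=batch_idx,
        num_batches=num_batches,
        task_id=hat_algorithm.task_id,
    )

    # Test: the stored binary mask of the requested seen task is applied instead.
    logits, mask, hidden_features = hat_algorithm.forward(x, stage="test", task_id=1)
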
def training_step(self, batch: Any, batch_idx: int) -> dict[str, torch.Tensor]:
296    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
297        r"""Training step for current task `self.task_id`.
298
299        **Args:**
300        - **batch** (`Any`): a batch of training data.
301        - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
302
303        **Returns:**
304        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Key (`str`) is the metrics name, value (`Tensor`) is the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. For HAT, it includes 'mask' and 'capacity' for logging.
305        """
306        x, y = batch
307
308        # zero the gradients before forward pass in manual optimisation mode
309        opt = self.optimizers()
310        opt.zero_grad()
311
312        # classification loss
313        num_batches = self.trainer.num_training_batches
314        logits, mask, hidden_features = self.forward(
315            x,
316            stage="train",
317            batch_idx=batch_idx,
318            num_batches=num_batches,
319            task_id=self.task_id,
320        )
321        loss_cls = self.criterion(logits, y)
322
323        # regularisation loss. See chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
324        loss_reg, network_sparsity = self.mark_sparsity_reg(
325            mask, self.cumulative_mask_for_previous_tasks
326        )
327
328        # total loss
329        loss = loss_cls + loss_reg
330
331        # backward step (manually)
332        self.manual_backward(loss)  # calculate the gradients
333        # HAT hard clip gradients by the cumulative masks. See equation (2) inchapter 2.3 "Network Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a). Network capacity is calculated along with this process. Network capacity is defined as the average adjustment rate over all paramaters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
334        capacity = self.clip_grad_by_adjustment(
335            network_sparsity=network_sparsity,  # pass a keyword argument network sparsity here to make it compatible with AdaHAT. AdaHAT inherits this `training_step()` method.
336        )
337        # compensate the gradients of task embedding. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
338        self.compensate_task_embedding_gradients(
339            batch_idx=batch_idx,
340            num_batches=num_batches,
341        )
342        # update parameters with the modified gradients
343        opt.step()
344
345        # accuracy of the batch
346        acc = (logits.argmax(dim=1) == y).float().mean()
347
348        return {
349            "loss": loss,  # Return loss is essential for training step, or backpropagation will fail
350            "loss_cls": loss_cls,
351            "loss_reg": loss_reg,
352            "acc": acc,
353            "hidden_features": hidden_features,
354            "mask": mask,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
355            "capacity": capacity,
356        }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.
  • batch_idx (int): the index of the batch. Used for calculating annealed scalar in HAT. See chapter 2.4 "Hard Attention Training" in HAT paper.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and other metrics from this training step. Key (str) is the metric name, value (Tensor) is the metric. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. For HAT, it also includes 'mask' and 'capacity' for logging.
def on_train_end(self) -> None:
358    def on_train_end(self) -> None:
359        r"""Store the mask and update cumulative mask after training the task."""
360
361        # store the mask for the current task
362        mask_t = {
363            layer_name: self.backbone.gate_fn(
364                self.backbone.task_embedding_t[layer_name].weight * self.s_max
365            )
366            .squeeze()
367            .detach()
368            for layer_name in self.backbone.weighted_layer_names
369        }
370
371        self.masks[f"{self.task_id}"] = mask_t
372
373        # update the cumulative and summative masks
374        self.cumulative_mask_for_previous_tasks = {
375            layer_name: torch.max(
376                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
377            )
378            for layer_name in self.backbone.weighted_layer_names
379        }

Store the mask for the current task and update the cumulative mask after training the task.
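
A toy illustration of this update, assuming the backbone's gate function is a sigmoid and using a single hypothetical layer "fc1":

    import torch

    s_max = 400.0
    task_embedding_fc1 = torch.tensor([6.0, -6.0, 0.05])  # hypothetical learned embedding row

    # With a sigmoid gate scaled by s_max, the stored mask is effectively binary.
    mask_fc1 = torch.sigmoid(s_max * task_embedding_fc1)      # ~[1., 0., 1.]

    # The cumulative mask of previous tasks grows by element-wise maximum.
    cumulative_before = torch.tensor([0.0, 1.0, 0.0])
    cumulative_after = torch.max(cumulative_before, mask_fc1)  # ~[1., 1., 1.]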

def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
381    def validation_step(self, batch: Any) -> dict[str, Tensor]:
382        r"""Validation step for current task `self.task_id`.
383
384        **Args:**
385        - **batch** (`Any`): a batch of validation data.
386
387        **Returns:**
388        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Key (`str`) is the metrics name, value (`Tensor`) is the metrics.
389        """
390        x, y = batch
391        logits, mask, hidden_features = self.forward(
392            x, stage="validation", task_id=self.task_id
393        )
394        loss_cls = self.criterion(logits, y)
395        acc = (logits.argmax(dim=1) == y).float().mean()
396
397        return {
398            "loss_cls": loss_cls,
399            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
400        }

Validation step for current task self.task_id.

Args:

  • batch (Any): a batch of validation data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and other metrics from this validation step. Key (str) is the metric name, value (Tensor) is the metric.
def test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
402    def test_step(
403        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
404    ) -> dict[str, Tensor]:
405        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
406
407        **Args:**
408        - **batch** (`Any`): a batch of test data.
409        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
410
411        **Returns:**
412        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Key (`str`) is the metrics name, value (`Tensor`) is the metrics.
413        """
414        test_task_id = dataloader_idx + 1
415
416        x, y = batch
417        logits, mask, hidden_features = self.forward(
418            x,
419            stage="test",
420            task_id=test_task_id,
421        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
422        loss_cls = self.criterion(logits, y)
423        acc = (logits.argmax(dim=1) == y).float().mean()
424
425        return {
426            "loss_cls": loss_cls,
427            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
428        }

Test step for the current task self.task_id, which tests all seen tasks indexed by dataloader_idx.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the index of the test dataloader, which corresponds to the seen task with ID dataloader_idx + 1. A default value of 0 is given, otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and other metrics from this test step. Key (str) is the metric name, value (Tensor) is the metric.
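
A small sketch of how test dataloaders map to seen tasks under this convention; the number of seen tasks here is hypothetical.

    # Task IDs are 1-indexed while dataloader indices are 0-indexed.
    num_seen_tasks = 3  # hypothetical
    for dataloader_idx in range(num_seen_tasks):
        test_task_id = dataloader_idx + 1
        print(f"dataloader {dataloader_idx} -> stored mask and head of task {test_task_id}")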