clarena.cl_algorithms.hat

The submodule in cl_algorithms for HAT (Hard Attention to the Task) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a).
  3"""
  4
  5__all__ = ["HAT"]
  6
  7import logging
  8from typing import Any
  9
 10import torch
 11from torch import Tensor
 12from torch.utils.data import DataLoader
 13
 14from clarena.backbones import HATMaskBackbone
 15from clarena.cl_algorithms import CLAlgorithm
 16from clarena.cl_algorithms.regularizers import HATMaskSparsityReg
 17from clarena.heads import HeadDIL, HeadsTIL
 18from clarena.utils.metrics import HATNetworkCapacityMetric
 19
 20# always get logger for built-in logging in each module
 21pylogger = logging.getLogger(__name__)
 22
 23
 24class HAT(CLAlgorithm):
 25    r"""[HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm.
 26
 27    An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
 28    """
 29
 30    def __init__(
 31        self,
 32        backbone: HATMaskBackbone,
 33        heads: HeadsTIL | HeadDIL,
 34        adjustment_mode: str,
 35        s_max: float,
 36        clamp_threshold: float,
 37        mask_sparsity_reg_factor: float,
 38        mask_sparsity_reg_mode: str = "original",
 39        task_embedding_init_mode: str = "N01",
 40        alpha: float | None = None,
 41        non_algorithmic_hparams: dict[str, Any] = {},
 42        **kwargs,
 43    ) -> None:
 44        r"""Initialize the HAT algorithm with the network.
 45
 46        **Args:**
 47        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
 48        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. HAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
 49        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
 50            1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 51            2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 52            3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 53            4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 54        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 55        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 56        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 57        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 58            1. 'original' (default): the original mask sparsity regularization in the HAT paper.
 59            2. 'cross': the cross version of mask sparsity regularization.
 60        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 61            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 62            2. 'U-11': uniform distribution $U(-1, 1)$.
 63            3. 'U01': uniform distribution $U(0, 1)$.
 64            4. 'U-10': uniform distribution $U(-1, 0)$.
 65            5. 'last': inherit the task embedding from the last task.
 66        - **alpha** (`float` | `None`): the `alpha` in the 'HAT-const-alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
 67        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 68        - **kwargs**: Reserved for multiple inheritance.
 69
 70        """
 71        super().__init__(
 72            backbone=backbone,
 73            heads=heads,
 74            non_algorithmic_hparams=non_algorithmic_hparams,
 75            **kwargs,
 76        )
 77
 78        # save additional algorithmic hyperparameters
 79        self.save_hyperparameters(
 80            "adjustment_mode",
 81            "s_max",
 82            "clamp_threshold",
 83            "mask_sparsity_reg_factor",
 84            "mask_sparsity_reg_mode",
 85            "task_embedding_init_mode",
 86            "alpha",
 87        )
 88
 89        self.adjustment_mode: str = adjustment_mode
 90        r"""The adjustment mode for gradient clipping."""
 91        self.s_max: float = s_max
 92        r"""The hyperparameter s_max."""
 93        self.clamp_threshold: float = clamp_threshold
 94        r"""The clamp threshold for task embedding gradient compensation."""
 95        self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor
 96        r"""The mask sparsity regularization factor."""
 97        self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode
 98        r"""The mask sparsity regularization mode."""
 99        self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg(
100            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
101        )
102        r"""The mask sparsity regularizer."""
103        self.task_embedding_init_mode: str = task_embedding_init_mode
104        r"""The task embedding initialization mode."""
105        self.alpha: float | None = alpha
106        r"""The hyperparameter alpha for `hat_const_alpha`."""
107        # self.epsilon: float | None = None
108        # r"""HAT doesn't use epsilon for `hat_const_alpha`. It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`."""
109
110        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
111        r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """
112
113        # set manual optimization
114        self.automatic_optimization = False
115
116        HAT.sanity_check(self)
117
118    def sanity_check(self) -> None:
119        r"""Sanity check."""
120
121        # check the backbone and heads
122        if not isinstance(self.backbone, HATMaskBackbone):
123            raise ValueError("The backbone should be an instance of `HATMaskBackbone`.")
124        if not isinstance(self.heads, HeadsTIL) and not isinstance(self.heads, HeadDIL):
125            raise ValueError(
126                "The heads should be an instance of `HeadsTIL` or `HeadDIL`."
127            )
128
129        # check marker sparsity regularization mode
130        if self.mask_sparsity_reg_mode not in ["original", "cross"]:
131            raise ValueError(
132                "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
133            )
134
135        # check task embedding initialization mode
136        if self.task_embedding_init_mode not in [
137            "N01",
138            "U01",
139            "U-10",
140            "masked",
141            "unmasked",
142        ]:
143            raise ValueError(
144                "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'."
145            )
146
147        # check adjustment mode `hat_const_alpha`
148        if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
149            raise ValueError(
150                "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
151            )
152
153    def on_train_start(self) -> None:
154        r"""Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task."""
155
156        self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode)
157
158        if self.backbone.batch_normalization is not None:
159            self.backbone.initialize_independent_bn()
160
161        # initialize the cumulative mask for the first task at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
162        if self.cumulative_mask_for_previous_tasks == {}:
163            for layer_name in self.backbone.weighted_layer_names:
164                layer = self.backbone.get_layer_by_name(
165                    layer_name
166                )  # get the layer by its name
167                num_units = layer.weight.shape[0]
168
169                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
170                    num_units
171                ).to(
172                    self.device
173                )  # the cumulative mask $\mathrm{M}^{<t}$ is initialized as a zeros mask ($t = 1$). See Eq. (2) in Sec. 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
174
175    def clip_grad_by_mask(
176        self,
177        mask: dict[int, Tensor],
178        aggregation_mode: str,
179    ) -> None:
180        r"""Clip the gradients by neuron mask.
181
182        **Args:**
183        - **mask** (`dict[int, Tensor]`): the neuron-wise mask which the gradients outside are clipped. Keys are layer names and values are the corresponding mask tensor for the layer.
184        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
185            - 'min': takes the minimum of the two connected unit measures.
186            - 'max': takes the maximum of the two connected unit measures.
187            - 'mean': takes the mean of the two connected unit measures.
188
189        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system.
190        This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
191        """
192
193        for layer_name in self.backbone.weighted_layer_names:
194
195            layer = self.backbone.get_layer_by_name(
196                layer_name
197            )  # get the layer by its name
198
199            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
200                neuron_wise_measure=mask,
201                layer_name=layer_name,
202                aggregation_mode=aggregation_mode,
203            )
204
205            # apply the adjustment rate to the gradients
206            layer.weight.grad.data *= weight_mask
207            if layer.bias is not None:
208                layer.bias.grad.data *= bias_mask
209
210    def clip_grad_by_adjustment(
211        self,
212        **kwargs,
213    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
214        r"""Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
215
216        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system.
217        This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
218
219        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters.
220        See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
221
222        **Returns:**
223        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
224        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer name and values (`Tensor`) are the adjustment rate tensors.
225        - **capacity** (`Tensor`): the calculated network capacity.
226        """
227
228        # initialize network capacity metric
229        capacity = HATNetworkCapacityMetric().to(self.device)
230        adjustment_rate_weight = {}
231        adjustment_rate_bias = {}
232
233        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist)
234        for layer_name in self.backbone.weighted_layer_names:
235
236            layer = self.backbone.get_layer_by_name(
237                layer_name
238            )  # get the layer by its name
239
240            # placeholder for the adjustment rate to avoid the error of using it before assignment
241            adjustment_rate_weight_layer = 1
242            adjustment_rate_bias_layer = 1
243
244            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
245                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
246                layer_name=layer_name,
247                aggregation_mode="min",
248            )
249
250            if self.adjustment_mode == "hat":
251                adjustment_rate_weight_layer = 1 - weight_mask
252                adjustment_rate_bias_layer = 1 - bias_mask
253
254            elif self.adjustment_mode == "hat_random":
255                adjustment_rate_weight_layer = torch.rand_like(
256                    weight_mask
257                ) * weight_mask + (1 - weight_mask)
258                adjustment_rate_bias_layer = torch.rand_like(bias_mask) * bias_mask + (
259                    1 - bias_mask
260                )
261
262            elif self.adjustment_mode == "hat_const_alpha":
263                adjustment_rate_weight_layer = self.alpha * torch.ones_like(
264                    weight_mask
265                ) * weight_mask + (1 - weight_mask)
266                adjustment_rate_bias_layer = self.alpha * torch.ones_like(
267                    bias_mask
268                ) * bias_mask + (1 - bias_mask)
269
270            elif self.adjustment_mode == "hat_const_1":
271                adjustment_rate_weight_layer = torch.ones_like(
272                    weight_mask
273                ) * weight_mask + (1 - weight_mask)
274                adjustment_rate_bias_layer = torch.ones_like(bias_mask) * bias_mask + (
275                    1 - bias_mask
276                )
277
278            # apply the adjustment rate to the gradients
279            layer.weight.grad.data *= adjustment_rate_weight_layer
280            if layer.bias is not None:
281                layer.bias.grad.data *= adjustment_rate_bias_layer
282
283            # store the adjustment rate for logging
284            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
285            if layer.bias is not None:
286                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
287
288            # update network capacity metric
289            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
290
291        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
292
293    def compensate_task_embedding_gradients(
294        self,
295        batch_idx: int,
296        num_batches: int,
297    ) -> None:
298        r"""Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
299
300        **Args:**
301        - **batch_idx** (`int`): the current training batch index.
302        - **num_batches** (`int`): the total number of training batches.
303        """
304
305        for te in self.backbone.task_embedding_t.values():
306            anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
307                batch_idx - 1
308            ) / (
309                num_batches - 1
310            )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
311
312            num = (
313                torch.cosh(
314                    torch.clamp(
315                        anneal_scalar * te.weight.data,
316                        -self.clamp_threshold,
317                        self.clamp_threshold,
318                    )
319                )
320                + 1
321            )
322
323            den = torch.cosh(te.weight.data) + 1
324
325            compensation = self.s_max / anneal_scalar * num / den
326
327            te.weight.grad.data *= compensation
328
329    def forward(
330        self,
331        input: torch.Tensor,
332        stage: str,
333        task_id: int | None = None,
334        batch_idx: int | None = None,
335        num_batches: int | None = None,
336    ) -> tuple[Tensor, dict[str, Tensor]]:
337        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
338
339        **Args:**
340        - **input** (`Tensor`): The input tensor from data.
341        - **stage** (`str`): the stage of the forward pass; one of:
342            1. 'train': training stage.
343            2. 'validation': validation stage.
344            3. 'test': testing stage.
345        - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
346        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
347        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
348
349        **Returns:**
350        - **logits** (`Tensor`): the output logits tensor.
351        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ).
352        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
353        """
354        feature, mask, activations = self.backbone(
355            input,
356            stage=stage,
357            s_max=self.s_max if stage == "train" or stage == "validation" else None,
358            batch_idx=batch_idx if stage == "train" else None,
359            num_batches=num_batches if stage == "train" else None,
360            test_task_id=task_id if stage == "test" else None,
361        )
362        logits = self.heads(feature, task_id)
363
364        return (
365            logits
366            if self.if_forward_func_return_logits_only
367            else (logits, mask, activations)
368        )
369
370    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
371        r"""Training step for current task `self.task_id`.
372
373        **Args:**
374        - **batch** (`Any`): a batch of training data.
375        - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
376
377        **Returns:**
378        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
379        """
380        x, y = batch
381
382        # zero the gradients before forward pass in manual optimization mode
383        opt = self.optimizers()
384        opt.zero_grad()
385
386        # classification loss
387        num_batches = self.trainer.num_training_batches
388        logits, mask, activations = self.forward(
389            x,
390            stage="train",
391            batch_idx=batch_idx,
392            num_batches=num_batches,
393            task_id=self.task_id,
394        )
395        loss_cls = self.criterion(logits, y)
396
397        # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
398        hat_mask_sparsity_reg, network_sparsity = self.mark_sparsity_reg(
399            mask, self.cumulative_mask_for_previous_tasks
400        )
401
402        # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
403        loss = loss_cls + hat_mask_sparsity_reg
404
405        # backward step (manually)
406        self.manual_backward(loss)  # calculate the gradients
407        # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
408        # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
409
410        adjustment_rate_weight, adjustment_rate_bias, capacity = (
411            self.clip_grad_by_adjustment(
412                network_sparsity=network_sparsity,  # passed for compatibility with AdaHAT, which inherits this method
413            )
414        )
415        # compensate the gradients of task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
416        self.compensate_task_embedding_gradients(
417            batch_idx=batch_idx,
418            num_batches=num_batches,
419        )
420        # update parameters with the modified gradients
421        opt.step()
422
423        # predicted labels
424        preds = logits.argmax(dim=1)
425
426        # accuracy of the batch
427        acc = (preds == y).float().mean()
428
429        return {
430            "preds": preds,
431            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
432            "loss_cls": loss_cls,
433            "hat_mask_sparsity_reg": hat_mask_sparsity_reg,
434            "acc": acc,
435            "activations": activations,
436            "logits": logits,
437            "mask": mask,  # return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
438            "input": x,  # return the input batch for Captum to use
439            "target": y,  # return the target batch for Captum to use
440            "adjustment_rate_weight": adjustment_rate_weight,  # return the adjustment rate for weights and biases for logging
441            "adjustment_rate_bias": adjustment_rate_bias,
442            "capacity": capacity,  # return the network capacity for logging
443        }
444
445    def on_train_end(self) -> None:
446        r"""The mask and update the cumulative mask after training the task."""
447
448        # store the mask for the current task
449        mask_t = self.backbone.store_mask()
450
451        # store the batch normalization if necessary
452        if self.backbone.batch_normalization is not None:
453            self.backbone.store_bn()
454
455        # update the cumulative mask. See the first Eq. in Sec 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
456        self.cumulative_mask_for_previous_tasks = {
457            layer_name: torch.max(
458                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
459            )
460            for layer_name in self.backbone.weighted_layer_names
461        }
462
463    def validation_step(self, batch: Any) -> dict[str, Tensor]:
464        r"""Validation step for current task `self.task_id`.
465
466        **Args:**
467        - **batch** (`Any`): a batch of validation data.
468
469        **Returns:**
470        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
471        """
472        x, y = batch
473        logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
474        loss_cls = self.criterion(logits, y)
475        preds = logits.argmax(dim=1)
476        acc = (preds == y).float().mean()
477
478        return {
479            "preds": preds,
480            "loss_cls": loss_cls,
481            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
482        }
483
484    def test_step(
485        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
486    ) -> dict[str, Tensor]:
487        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
488
489        **Args:**
490        - **batch** (`Any`): a batch of test data.
491        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
492
493        **Returns:**
494        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
495        """
496        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
497
498        x, y = batch
499        logits, _, _ = self.forward(
500            x,
501            stage="test",
502            task_id=test_task_id,
503        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
504        loss_cls = self.criterion(logits, y)
505        preds = logits.argmax(dim=1)
506        acc = (preds == y).float().mean()
507
508        return {
509            "preds": preds,
510            "loss_cls": loss_cls,
511            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
512        }
class HAT(CLAlgorithm):
    r"""[HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm.

    An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
    """

    def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL | HeadDIL,
        adjustment_mode: str,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        task_embedding_init_mode: str = "N01",
        alpha: float | None = None,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        r"""Initialize the HAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. HAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
            1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
            1. 'original' (default): the original mask sparsity regularization in the HAT paper.
            2. 'cross': the cross version of mask sparsity regularization.
        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit the task embedding from the last task.
        - **alpha** (`float` | `None`): the `alpha` in the 'HAT-const-alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. If `None` (default), an empty dict is used.
        - **kwargs**: Reserved for multiple inheritance.

        """
        # `None` sentinel instead of a `{}` default avoids the shared-mutable-default-argument pitfall
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}

        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "adjustment_mode",
            "s_max",
            "clamp_threshold",
            "mask_sparsity_reg_factor",
            "mask_sparsity_reg_mode",
            "task_embedding_init_mode",
            "alpha",
        )

        self.adjustment_mode: str = adjustment_mode
        r"""The adjustment mode for gradient clipping."""
        self.s_max: float = s_max
        r"""The hyperparameter s_max."""
        self.clamp_threshold: float = clamp_threshold
        r"""The clamp threshold for task embedding gradient compensation."""
        self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor
        r"""The mask sparsity regularization factor."""
        self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode
        r"""The mask sparsity regularization mode."""
        self.mask_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg(
            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
        )
        r"""The mask sparsity regularizer."""
        self.mark_sparsity_reg: HATMaskSparsityReg = self.mask_sparsity_reg
        r"""Backward-compatible alias of `mask_sparsity_reg` (the original attribute name contained a typo). Prefer `mask_sparsity_reg`."""
        self.task_embedding_init_mode: str = task_embedding_init_mode
        r"""The task embedding initialization mode."""
        self.alpha: float | None = alpha
        r"""The hyperparameter alpha for `hat_const_alpha`."""
        # NOTE: HAT doesn't use epsilon for `hat_const_alpha`. An `epsilon` attribute is deliberately
        # not defined here; it exists only for consistency with `clip_grad_by_adjustment()` in
        # `HATMaskBackbone`-based subclasses (e.g. AdaHAT).

        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """

        # set manual optimization: HAT needs to modify gradients between backward and optimizer step
        self.automatic_optimization = False

        # call HAT's own check explicitly (not `self.sanity_check()`) so subclasses overriding
        # `sanity_check()` still get the base-class invariants verified
        HAT.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check.

        **Raises:**
        - **ValueError**: when the backbone, heads, or any hyperparameter is invalid.
        """

        # check the backbone and heads
        if not isinstance(self.backbone, HATMaskBackbone):
            raise ValueError("The backbone should be an instance of `HATMaskBackbone`.")
        if not isinstance(self.heads, HeadsTIL) and not isinstance(self.heads, HeadDIL):
            raise ValueError(
                "The heads should be an instance of `HeadsTIL` or `HeadDIL`."
            )

        # check mask sparsity regularization mode
        if self.mask_sparsity_reg_mode not in ["original", "cross"]:
            raise ValueError(
                "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
            )

        # check task embedding initialization mode
        # NOTE(review): this accepted list ('masked', 'unmasked') disagrees with the `__init__`
        # docstring ('U-11', 'last') — confirm against `HATMaskBackbone.initialize_task_embedding()`
        # which defines the actually supported modes.
        if self.task_embedding_init_mode not in [
            "N01",
            "U01",
            "U-10",
            "masked",
            "unmasked",
        ]:
            raise ValueError(
                "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'."
            )

        # check adjustment mode `hat_const_alpha`
        if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
            raise ValueError(
                "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
            )

    def on_train_start(self) -> None:
        r"""Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task."""

        self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode)

        if self.backbone.batch_normalization is not None:
            self.backbone.initialize_independent_bn()

        # initialize the cumulative mask for the first task at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
        if self.cumulative_mask_for_previous_tasks == {}:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name
                num_units = layer.weight.shape[0]

                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
                    num_units
                ).to(
                    self.device
                )  # the cumulative mask $\mathrm{M}^{<t}$ is initialized as a zeros mask ($t = 1$). See Eq. (2) in Sec. 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)

    def clip_grad_by_mask(
        self,
        mask: dict[int, Tensor],
        aggregation_mode: str,
    ) -> None:
        r"""Clip the gradients by neuron mask.

        **Args:**
        - **mask** (`dict[int, Tensor]`): the neuron-wise mask which the gradients outside are clipped. Keys are layer names and values are the corresponding mask tensor for the layer.
        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
            - 'min': takes the minimum of the two connected unit measures.
            - 'max': takes the maximum of the two connected unit measures.
            - 'mean': takes the mean of the two connected unit measures.

        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system.
        This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
        """

        for layer_name in self.backbone.weighted_layer_names:

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
                neuron_wise_measure=mask,
                layer_name=layer_name,
                aggregation_mode=aggregation_mode,
            )

            # apply the adjustment rate to the gradients
            layer.weight.grad.data *= weight_mask
            if layer.bias is not None:
                layer.bias.grad.data *= bias_mask

    def clip_grad_by_adjustment(
        self,
        **kwargs,
    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
        r"""Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system.
        This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters.
        See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        **Args:**
        - **kwargs**: accepted and ignored here (e.g. `network_sparsity`); kept so subclasses such as AdaHAT can override with extra arguments.

        **Returns:**
        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer name and values (`Tensor`) are the adjustment rate tensors.
        - **capacity** (`Tensor`): the calculated network capacity.
        """

        # initialize network capacity metric
        capacity = HATNetworkCapacityMetric().to(self.device)
        adjustment_rate_weight = {}
        adjustment_rate_bias = {}

        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist)
        for layer_name in self.backbone.weighted_layer_names:

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # placeholder for the adjustment rate to avoid the error of using it before assignment
            adjustment_rate_weight_layer = 1
            adjustment_rate_bias_layer = 1

            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
                layer_name=layer_name,
                aggregation_mode="min",
            )

            if self.adjustment_mode == "hat":
                # hard clip: zero gradient wherever the cumulative mask is 1
                adjustment_rate_weight_layer = 1 - weight_mask
                adjustment_rate_bias_layer = 1 - bias_mask

            elif self.adjustment_mode == "hat_random":
                # random 0-1 adjustment on masked parameters, 1 elsewhere
                adjustment_rate_weight_layer = torch.rand_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = torch.rand_like(bias_mask) * bias_mask + (
                    1 - bias_mask
                )

            elif self.adjustment_mode == "hat_const_alpha":
                # constant `alpha` adjustment on masked parameters, 1 elsewhere
                adjustment_rate_weight_layer = self.alpha * torch.ones_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = self.alpha * torch.ones_like(
                    bias_mask
                ) * bias_mask + (1 - bias_mask)

            elif self.adjustment_mode == "hat_const_1":
                # constant 1 adjustment everywhere, i.e. no gradient constraint
                adjustment_rate_weight_layer = torch.ones_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = torch.ones_like(bias_mask) * bias_mask + (
                    1 - bias_mask
                )

            # apply the adjustment rate to the gradients
            layer.weight.grad.data *= adjustment_rate_weight_layer
            if layer.bias is not None:
                layer.bias.grad.data *= adjustment_rate_bias_layer

            # store the adjustment rate for logging
            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
            if layer.bias is not None:
                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer

            # update network capacity metric
            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)

        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()

    def compensate_task_embedding_gradients(
        self,
        batch_idx: int,
        num_batches: int,
    ) -> None:
        r"""Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Args:**
        - **batch_idx** (`int`): the current training batch index.
        - **num_batches** (`int`): the total number of training batches.
        """

        for te in self.backbone.task_embedding_t.values():
            anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
                batch_idx - 1
            ) / (
                num_batches - 1
            )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)

            # clamp the scaled embedding before cosh to avoid numerical overflow
            num = (
                torch.cosh(
                    torch.clamp(
                        anneal_scalar * te.weight.data,
                        -self.clamp_threshold,
                        self.clamp_threshold,
                    )
                )
                + 1
            )

            den = torch.cosh(te.weight.data) + 1

            compensation = self.s_max / anneal_scalar * num / den

            te.weight.grad.data *= compensation

    def forward(
        self,
        input: torch.Tensor,
        stage: str,
        task_id: int | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor. If `self.if_forward_func_return_logits_only` is set, only the logits are returned.
        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ).
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
        """
        feature, mask, activations = self.backbone(
            input,
            stage=stage,
            s_max=self.s_max if stage == "train" or stage == "validation" else None,
            batch_idx=batch_idx if stage == "train" else None,
            num_batches=num_batches if stage == "train" else None,
            test_task_id=task_id if stage == "test" else None,
        )
        logits = self.heads(feature, task_id)

        return (
            logits
            if self.if_forward_func_return_logits_only
            else (logits, mask, activations)
        )

    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.
        - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
        """
        x, y = batch

        # zero the gradients before forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        num_batches = self.trainer.num_training_batches
        logits, mask, activations = self.forward(
            x,
            stage="train",
            batch_idx=batch_idx,
            num_batches=num_batches,
            task_id=self.task_id,
        )
        loss_cls = self.criterion(logits, y)

        # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        hat_mask_sparsity_reg, network_sparsity = self.mask_sparsity_reg(
            mask, self.cumulative_mask_for_previous_tasks
        )

        # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        loss = loss_cls + hat_mask_sparsity_reg

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
        # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).

        adjustment_rate_weight, adjustment_rate_bias, capacity = (
            self.clip_grad_by_adjustment(
                network_sparsity=network_sparsity,  # passed for compatibility with AdaHAT, which inherits this method
            )
        )
        # compensate the gradients of task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.compensate_task_embedding_gradients(
            batch_idx=batch_idx,
            num_batches=num_batches,
        )
        # update parameters with the modified gradients
        opt.step()

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "hat_mask_sparsity_reg": hat_mask_sparsity_reg,
            "acc": acc,
            "activations": activations,
            "logits": logits,
            "mask": mask,  # return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
            "input": x,  # return the input batch for Captum to use
            "target": y,  # return the target batch for Captum to use
            "adjustment_rate_weight": adjustment_rate_weight,  # return the adjustment rate for weights and biases for logging
            "adjustment_rate_bias": adjustment_rate_bias,
            "capacity": capacity,  # return the network capacity for logging
        }

    def on_train_end(self) -> None:
        r"""Store the mask and update the cumulative mask after training the task."""

        # store the mask for the current task
        mask_t = self.backbone.store_mask()

        # store the batch normalization if necessary
        if self.backbone.batch_normalization is not None:
            self.backbone.store_bn()

        # update the cumulative mask. See the first Eq. in Sec 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.cumulative_mask_for_previous_tasks = {
            layer_name: torch.max(
                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
            )
            for layer_name in self.backbone.weighted_layer_names
        }

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
        }

    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data (a batch yielded by a test dataloader, not the dataloader itself).
        - **batch_idx** (`int`): the index of the batch. Not used here; required by the PyTorch Lightning API.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y = batch
        logits, _, _ = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
        }

HAT (Hard Attention to the Task) algorithm.

An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.

HAT( backbone: clarena.backbones.HATMaskBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadDIL, adjustment_mode: str, s_max: float, clamp_threshold: float, mask_sparsity_reg_factor: float, mask_sparsity_reg_mode: str = 'original', task_embedding_init_mode: str = 'N01', alpha: float | None = None, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
 31    def __init__(
 32        self,
 33        backbone: HATMaskBackbone,
 34        heads: HeadsTIL | HeadDIL,
 35        adjustment_mode: str,
 36        s_max: float,
 37        clamp_threshold: float,
 38        mask_sparsity_reg_factor: float,
 39        mask_sparsity_reg_mode: str = "original",
 40        task_embedding_init_mode: str = "N01",
 41        alpha: float | None = None,
 42        non_algorithmic_hparams: dict[str, Any] = {},
 43        **kwargs,
 44    ) -> None:
 45        r"""Initialize the HAT algorithm with the network.
 46
 47        **Args:**
 48        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
 49        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. HAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
 50        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
 51            1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 52            2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 53            3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 54            4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 55        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 56        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 57        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 58        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 59            1. 'original' (default): the original mask sparsity regularization in the HAT paper.
 60            2. 'cross': the cross version of mask sparsity regularization.
 61        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 62            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 63            2. 'U-11': uniform distribution $U(-1, 1)$.
 64            3. 'U01': uniform distribution $U(0, 1)$.
 65            4. 'U-10': uniform distribution $U(-1, 0)$.
 66            5. 'last': inherit the task embedding from the last task.
 67        - **alpha** (`float` | `None`): the `alpha` in the 'HAT-const-alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
 68        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 69        - **kwargs**: Reserved for multiple inheritance.
 70
 71        """
 72        super().__init__(
 73            backbone=backbone,
 74            heads=heads,
 75            non_algorithmic_hparams=non_algorithmic_hparams,
 76            **kwargs,
 77        )
 78
 79        # save additional algorithmic hyperparameters
 80        self.save_hyperparameters(
 81            "adjustment_mode",
 82            "s_max",
 83            "clamp_threshold",
 84            "mask_sparsity_reg_factor",
 85            "mask_sparsity_reg_mode",
 86            "task_embedding_init_mode",
 87            "alpha",
 88        )
 89
 90        self.adjustment_mode: str = adjustment_mode
 91        r"""The adjustment mode for gradient clipping."""
 92        self.s_max: float = s_max
 93        r"""The hyperparameter s_max."""
 94        self.clamp_threshold: float = clamp_threshold
 95        r"""The clamp threshold for task embedding gradient compensation."""
 96        self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor
 97        r"""The mask sparsity regularization factor."""
 98        self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode
 99        r"""The mask sparsity regularization mode."""
100        self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg(
101            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
102        )
103        r"""The mask sparsity regularizer."""
104        self.task_embedding_init_mode: str = task_embedding_init_mode
105        r"""The task embedding initialization mode."""
106        self.alpha: float | None = alpha
107        r"""The hyperparameter alpha for `hat_const_alpha`."""
108        # self.epsilon: float | None = None
109        # r"""HAT doesn't use epsilon for `hat_const_alpha`. It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`."""
110
111        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
112        r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """
113
114        # set manual optimization
115        self.automatic_optimization = False
116
117        HAT.sanity_check(self)

Initialize the HAT algorithm with the network.

Args:

  • backbone (HATMaskBackbone): must be a backbone network with the HAT mask mechanism.
  • heads (HeadsTIL | HeadDIL): output heads. HAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
  • adjustment_mode (str): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
    1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
    2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the AdaHAT paper.
    3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value alpha. See "Baselines" in Sec. 4.1 in the AdaHAT paper.
    4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the AdaHAT paper.
  • s_max (float): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
  • clamp_threshold (float): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
  • mask_sparsity_reg_factor (float): hyperparameter, the regularization factor for mask sparsity.
  • mask_sparsity_reg_mode (str): the mode of mask sparsity regularization, must be one of:
    1. 'original' (default): the original mask sparsity regularization in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularization.
  • task_embedding_init_mode (str): the initialization mode for task embeddings, must be one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
  • alpha (float | None): the alpha in the 'hat_const_alpha' mode. Applies only when adjustment_mode is 'hat_const_alpha'.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
adjustment_mode: str

The adjustment mode for gradient clipping.

s_max: float

The hyperparameter s_max.

clamp_threshold: float

The clamp threshold for task embedding gradient compensation.

mask_sparsity_reg_factor: float

The mask sparsity regularization factor.

mask_sparsity_reg_mode: str

The mask sparsity regularization mode.

The mask sparsity regularizer.

task_embedding_init_mode: str

The task embedding initialization mode.

alpha: float | None

The hyperparameter alpha for hat_const_alpha.

cumulative_mask_for_previous_tasks: dict[str, torch.Tensor]

The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1, \cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ).

automatic_optimization: bool
290    @property
291    def automatic_optimization(self) -> bool:
292        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
293        return self._automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

def sanity_check(self) -> None:
119    def sanity_check(self) -> None:
120        r"""Sanity check."""
121
122        # check the backbone and heads
123        if not isinstance(self.backbone, HATMaskBackbone):
124            raise ValueError("The backbone should be an instance of `HATMaskBackbone`.")
125        if not isinstance(self.heads, HeadsTIL) and not isinstance(self.heads, HeadDIL):
126            raise ValueError(
127                "The heads should be an instance of `HeadsTIL` or `HeadDIL`."
128            )
129
130        # check marker sparsity regularization mode
131        if self.mask_sparsity_reg_mode not in ["original", "cross"]:
132            raise ValueError(
133                "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
134            )
135
136        # check task embedding initialization mode
137        if self.task_embedding_init_mode not in [
138            "N01",
139            "U01",
140            "U-10",
141            "masked",
142            "unmasked",
143        ]:
144            raise ValueError(
145                "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'."
146            )
147
148        # check adjustment mode `hat_const_alpha`
149        if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
150            raise ValueError(
151                "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
152            )

Sanity check.

def on_train_start(self) -> None:
154    def on_train_start(self) -> None:
155        r"""Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task."""
156
157        self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode)
158
159        if self.backbone.batch_normalization is not None:
160            self.backbone.initialize_independent_bn()
161
162        # initialize the cumulative mask for the first task at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
163        if self.cumulative_mask_for_previous_tasks == {}:
164            for layer_name in self.backbone.weighted_layer_names:
165                layer = self.backbone.get_layer_by_name(
166                    layer_name
167                )  # get the layer by its name
168                num_units = layer.weight.shape[0]
169
170                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
171                    num_units
172                ).to(
173                    self.device
174                )  # the cumulative mask $\mathrm{M}^{<t}$ is initialized as a zeros mask ($t = 1$). See Eq. (2) in Sec. 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)

Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task.

def clip_grad_by_mask(self, mask: dict[int, torch.Tensor], aggregation_mode: str) -> None:
176    def clip_grad_by_mask(
177        self,
178        mask: dict[int, Tensor],
179        aggregation_mode: str,
180    ) -> None:
181        r"""Clip the gradients by neuron mask.
182
183        **Args:**
184        - **mask** (`dict[int, Tensor]`): the neuron-wise mask which the gradients outside are clipped. Keys are layer names and values are the corresponding mask tensor for the layer.
185        - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
186            - 'min': takes the minimum of the two connected unit measures.
187            - 'max': takes the maximum of the two connected unit measures.
188            - 'mean': takes the mean of the two connected unit measures.
189
190        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system.
191        This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
192        """
193
194        for layer_name in self.backbone.weighted_layer_names:
195
196            layer = self.backbone.get_layer_by_name(
197                layer_name
198            )  # get the layer by its name
199
200            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
201                neuron_wise_measure=mask,
202                layer_name=layer_name,
203                aggregation_mode=aggregation_mode,
204            )
205
206            # apply the adjustment rate to the gradients
207            layer.weight.grad.data *= weight_mask
208            if layer.bias is not None:
209                layer.bias.grad.data *= bias_mask

Clip the gradients by neuron mask.

Args:

  • mask (dict[int, Tensor]): the neuron-wise mask which the gradients outside are clipped. Keys are layer names and values are the corresponding mask tensor for the layer.
  • aggregation_mode (str): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of:
    • 'min': takes the minimum of the two connected unit measures.
    • 'max': takes the maximum of the two connected unit measures.
    • 'mean': takes the mean of the two connected unit measures.

Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

def clip_grad_by_adjustment( self, **kwargs) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], torch.Tensor]:
211    def clip_grad_by_adjustment(
212        self,
213        **kwargs,
214    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
215        r"""Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
216
217        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system.
218        This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
219
220        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters.
221        See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
222
223        **Returns:**
224        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
225        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer name and values (`Tensor`) are the adjustment rate tensors.
226        - **capacity** (`Tensor`): the calculated network capacity.
227        """
228
229        # initialize network capacity metric
230        capacity = HATNetworkCapacityMetric().to(self.device)
231        adjustment_rate_weight = {}
232        adjustment_rate_bias = {}
233
234        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist)
235        for layer_name in self.backbone.weighted_layer_names:
236
237            layer = self.backbone.get_layer_by_name(
238                layer_name
239            )  # get the layer by its name
240
241            # placeholder for the adjustment rate to avoid the error of using it before assignment
242            adjustment_rate_weight_layer = 1
243            adjustment_rate_bias_layer = 1
244
245            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
246                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
247                layer_name=layer_name,
248                aggregation_mode="min",
249            )
250
251            if self.adjustment_mode == "hat":
252                adjustment_rate_weight_layer = 1 - weight_mask
253                adjustment_rate_bias_layer = 1 - bias_mask
254
255            elif self.adjustment_mode == "hat_random":
256                adjustment_rate_weight_layer = torch.rand_like(
257                    weight_mask
258                ) * weight_mask + (1 - weight_mask)
259                adjustment_rate_bias_layer = torch.rand_like(bias_mask) * bias_mask + (
260                    1 - bias_mask
261                )
262
263            elif self.adjustment_mode == "hat_const_alpha":
264                adjustment_rate_weight_layer = self.alpha * torch.ones_like(
265                    weight_mask
266                ) * weight_mask + (1 - weight_mask)
267                adjustment_rate_bias_layer = self.alpha * torch.ones_like(
268                    bias_mask
269                ) * bias_mask + (1 - bias_mask)
270
271            elif self.adjustment_mode == "hat_const_1":
272                adjustment_rate_weight_layer = torch.ones_like(
273                    weight_mask
274                ) * weight_mask + (1 - weight_mask)
275                adjustment_rate_bias_layer = torch.ones_like(bias_mask) * bias_mask + (
276                    1 - bias_mask
277                )
278
279            # apply the adjustment rate to the gradients
280            layer.weight.grad.data *= adjustment_rate_weight_layer
281            if layer.bias is not None:
282                layer.bias.grad.data *= adjustment_rate_bias_layer
283
284            # store the adjustment rate for logging
285            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
286            if layer.bias is not None:
287                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
288
289            # update network capacity metric
290            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
291
292        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()

Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.

Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.

Returns:

  • adjustment_rate_weight (dict[str, Tensor]): the adjustment rate for weights. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
  • adjustment_rate_bias (dict[str, Tensor]): the adjustment rate for biases. Keys (str) are layer name and values (Tensor) are the adjustment rate tensors.
  • capacity (Tensor): the calculated network capacity.
def compensate_task_embedding_gradients(self, batch_idx: int, num_batches: int) -> None:
294    def compensate_task_embedding_gradients(
295        self,
296        batch_idx: int,
297        num_batches: int,
298    ) -> None:
299        r"""Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
300
301        **Args:**
302        - **batch_idx** (`int`): the current training batch index.
303        - **num_batches** (`int`): the total number of training batches.
304        """
305
306        for te in self.backbone.task_embedding_t.values():
307            anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
308                batch_idx - 1
309            ) / (
310                num_batches - 1
311            )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
312
313            num = (
314                torch.cosh(
315                    torch.clamp(
316                        anneal_scalar * te.weight.data,
317                        -self.clamp_threshold,
318                        self.clamp_threshold,
319                    )
320                )
321                + 1
322            )
323
324            den = torch.cosh(te.weight.data) + 1
325
326            compensation = self.s_max / anneal_scalar * num / den
327
328            te.weight.grad.data *= compensation

Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.

Args:

  • batch_idx (int): the current training batch index.
  • num_batches (int): the total number of training batches.
def forward( self, input: torch.Tensor, stage: str, task_id: int | None = None, batch_idx: int | None = None, num_batches: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
330    def forward(
331        self,
332        input: torch.Tensor,
333        stage: str,
334        task_id: int | None = None,
335        batch_idx: int | None = None,
336        num_batches: int | None = None,
337    ) -> tuple[Tensor, dict[str, Tensor]]:
338        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
339
340        **Args:**
341        - **input** (`Tensor`): The input tensor from data.
342        - **stage** (`str`): the stage of the forward pass; one of:
343            1. 'train': training stage.
344            2. 'validation': validation stage.
345            3. 'test': testing stage.
346        - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
347        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
348        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
349
350        **Returns:**
351        - **logits** (`Tensor`): the output logits tensor.
352        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ).
353        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
354        """
355        feature, mask, activations = self.backbone(
356            input,
357            stage=stage,
358            s_max=self.s_max if stage == "train" or stage == "validation" else None,
359            batch_idx=batch_idx if stage == "train" else None,
360            num_batches=num_batches if stage == "train" else None,
361            test_task_id=task_id if stage == "test" else None,
362        )
363        logits = self.heads(feature, task_id)
364
365        return (
366            logits
367            if self.if_forward_func_return_logits_only
368            else (logits, mask, activations)
369        )

The forward pass for data from task task_id. Note that it is nothing to do with forward() method in nn.Module.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): the stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • task_id (int| None): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task self.task_id. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
  • batch_idx (int | None): the current batch index. Applies only to training stage. For other stages, it is default None.
  • num_batches (int | None): the total number of batches. Applies only to training stage. For other stages, it is default None.

Returns:

  • logits (Tensor): the output logits tensor.
  • mask (dict[str, Tensor]): the mask for the current task. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has size (number of units, ).
  • activations (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this forward() method of HAT class.
def training_step(self, batch: Any, batch_idx: int) -> dict[str, torch.Tensor]:
371    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
372        r"""Training step for current task `self.task_id`.
373
374        **Args:**
375        - **batch** (`Any`): a batch of training data.
376        - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
377
378        **Returns:**
379        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
380        """
381        x, y = batch
382
383        # zero the gradients before forward pass in manual optimization mode
384        opt = self.optimizers()
385        opt.zero_grad()
386
387        # classification loss
388        num_batches = self.trainer.num_training_batches
389        logits, mask, activations = self.forward(
390            x,
391            stage="train",
392            batch_idx=batch_idx,
393            num_batches=num_batches,
394            task_id=self.task_id,
395        )
396        loss_cls = self.criterion(logits, y)
397
398        # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
399        hat_mask_sparsity_reg, network_sparsity = self.mark_sparsity_reg(
400            mask, self.cumulative_mask_for_previous_tasks
401        )
402
403        # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
404        loss = loss_cls + hat_mask_sparsity_reg
405
406        # backward step (manually)
407        self.manual_backward(loss)  # calculate the gradients
408        # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
409        # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
410
411        adjustment_rate_weight, adjustment_rate_bias, capacity = (
412            self.clip_grad_by_adjustment(
413                network_sparsity=network_sparsity,  # passed for compatibility with AdaHAT, which inherits this method
414            )
415        )
416        # compensate the gradients of task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
417        self.compensate_task_embedding_gradients(
418            batch_idx=batch_idx,
419            num_batches=num_batches,
420        )
421        # update parameters with the modified gradients
422        opt.step()
423
424        # predicted labels
425        preds = logits.argmax(dim=1)
426
427        # accuracy of the batch
428        acc = (preds == y).float().mean()
429
430        return {
431            "preds": preds,
432            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
433            "loss_cls": loss_cls,
434            "hat_mask_sparsity_reg": hat_mask_sparsity_reg,
435            "acc": acc,
436            "activations": activations,
437            "logits": logits,
438            "mask": mask,  # return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
439            "input": x,  # return the input batch for Captum to use
440            "target": y,  # return the target batch for Captum to use
441            "adjustment_rate_weight": adjustment_rate_weight,  # return the adjustment rate for weights and biases for logging
442            "adjustment_rate_bias": adjustment_rate_bias,
443            "capacity": capacity,  # return the network capacity for logging
444        }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.
  • batch_idx (int): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the HAT paper.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing loss and other metrics from this training step. Keys (str) are metric names, and values (Tensor) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
def on_train_end(self) -> None:
446    def on_train_end(self) -> None:
447        r"""The mask and update the cumulative mask after training the task."""
448
449        # store the mask for the current task
450        mask_t = self.backbone.store_mask()
451
452        # store the batch normalization if necessary
453        if self.backbone.batch_normalization is not None:
454            self.backbone.store_bn()
455
456        # update the cumulative mask. See the first Eq. in Sec 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
457        self.cumulative_mask_for_previous_tasks = {
458            layer_name: torch.max(
459                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
460            )
461            for layer_name in self.backbone.weighted_layer_names
462        }

Store the mask and update the cumulative mask after training the task.

def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
464    def validation_step(self, batch: Any) -> dict[str, Tensor]:
465        r"""Validation step for current task `self.task_id`.
466
467        **Args:**
468        - **batch** (`Any`): a batch of validation data.
469
470        **Returns:**
471        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
472        """
473        x, y = batch
474        logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
475        loss_cls = self.criterion(logits, y)
476        preds = logits.argmax(dim=1)
477        acc = (preds == y).float().mean()
478
479        return {
480            "preds": preds,
481            "loss_cls": loss_cls,
482            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
483        }

Validation step for current task self.task_id.

Args:

  • batch (Any): a batch of validation data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this validation step. Keys (str) are the metrics names, and values (Tensor) are the metrics.
def test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
485    def test_step(
486        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
487    ) -> dict[str, Tensor]:
488        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
489
490        **Args:**
491        - **batch** (`Any`): a batch of test data.
492        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
493
494        **Returns:**
495        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
496        """
497        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
498
499        x, y = batch
500        logits, _, _ = self.forward(
501            x,
502            stage="test",
503            task_id=test_task_id,
504        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
505        loss_cls = self.criterion(logits, y)
506        preds = logits.argmax(dim=1)
507        acc = (preds == y).float().mean()
508
509        return {
510            "preds": preds,
511            "loss_cls": loss_cls,
512            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
513        }

Test step for current task self.task_id, which tests for all seen tasks indexed by dataloader_idx.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this test step. Keys (str) are the metrics names, and values (Tensor) are the metrics.