# clarena.cl_algorithms.hat
#
# The submodule in cl_algorithms for HAT (Hard Attention to the Task) algorithm.

r"""
The submodule in `cl_algorithms` for [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a).
"""

__all__ = ["HAT"]

import logging
from typing import Any

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from clarena.backbones import HATMaskBackbone
from clarena.cl_algorithms import CLAlgorithm
from clarena.cl_algorithms.regularizers import HATMaskSparsityReg
from clarena.heads import HeadsTIL
from clarena.utils.metrics import HATNetworkCapacityMetric

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)
 22
 23
class HAT(CLAlgorithm):
    r"""[HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm.

    An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
    """

    def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL,
        adjustment_mode: str,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        task_embedding_init_mode: str = "N01",
        alpha: float | None = None,
        non_algorithmic_hparams: dict[str, Any] = {},  # NOTE(review): mutable default argument — shared across instances; safe only if never mutated in place. Consider a `None` sentinel.
    ) -> None:
        r"""Initialize the HAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL`): output heads. HAT only supports TIL (Task-Incremental Learning).
        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
            1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
            1. 'original' (default): the original mask sparsity regularization in the HAT paper.
            2. 'cross': the cross version of mask sparsity regularization.
        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit the task embedding from the last task.
        - **alpha** (`float` | `None`): the `alpha` in the 'HAT-const-alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.

        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "adjustment_mode",
            "s_max",
            "clamp_threshold",
            "mask_sparsity_reg_factor",
            "mask_sparsity_reg_mode",
            "task_embedding_init_mode",
            "alpha",
        )

        self.adjustment_mode: str = adjustment_mode
        r"""The adjustment mode for gradient clipping."""
        self.s_max: float = s_max
        r"""The hyperparameter s_max."""
        self.clamp_threshold: float = clamp_threshold
        r"""The clamp threshold for task embedding gradient compensation."""
        self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor
        r"""The mask sparsity regularization factor."""
        self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode
        r"""The mask sparsity regularization mode."""
        # NOTE(review): "mark" looks like a typo for "mask"; the name is kept as-is
        # because `training_step()` below references `self.mark_sparsity_reg`.
        self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg(
            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
        )
        r"""The mask sparsity regularizer."""
        self.task_embedding_init_mode: str = task_embedding_init_mode
        r"""The task embedding initialization mode."""
        self.alpha: float | None = alpha
        r"""The hyperparameter alpha for `hat_const_alpha`."""
        # self.epsilon: float | None = None
        # r"""HAT doesn't use epsilon for `hat_const_alpha`. It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`."""

        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """

        # set manual optimization: `training_step()` below zeroes grads, backprops,
        # clips/compensates gradients manually, then steps the optimizer itself
        self.automatic_optimization = False

        # call HAT's own sanity check explicitly (not `self.sanity_check()`) so a
        # subclass override is not invoked during HAT's `__init__()`
        HAT.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check.

        **Raises:**
        - **ValueError**: when backbone/heads types or any mode/hyperparameter
          combination is invalid.
        """

        # check the backbone and heads
        if not isinstance(self.backbone, HATMaskBackbone):
            raise ValueError("The backbone should be an instance of `HATMaskBackbone`.")
        if not isinstance(self.heads, HeadsTIL):
            raise ValueError("The heads should be an instance of `HeadsTIL`.")

        # check marker sparsity regularization mode
        if self.mask_sparsity_reg_mode not in ["original", "cross"]:
            raise ValueError(
                "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
            )

        # check task embedding initialization mode
        # NOTE(review): this accepted list disagrees with the `__init__()` docstring —
        # 'U-11' and 'last' are documented there but rejected here, while 'masked' and
        # 'unmasked' are accepted here but undocumented. Confirm the correct mode set
        # against `HATMaskBackbone.initialize_task_embedding()` and align both places.
        if self.task_embedding_init_mode not in [
            "N01",
            "U01",
            "U-10",
            "masked",
            "unmasked",
        ]:
            raise ValueError(
                "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'."
            )

        # check adjustment mode `hat_const_alpha`
        if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
            raise ValueError(
                "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
            )

    def on_train_start(self) -> None:
        r"""Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task."""

        self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode)

        self.backbone.initialize_independent_bn()

        # initialize the cumulative mask for the first task at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name
                # one mask unit per output unit of the layer (weight dim 0)
                num_units = layer.weight.shape[0]

                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
                    num_units
                ).to(
                    self.device
                )  # the cumulative mask $\mathrm{M}^{<t}$ is initialized as a zeros mask ($t = 1$). See Eq. (2) in Sec. 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)

                # self.neuron_first_task[layer_name] = [None] * num_units

    def clip_grad_by_adjustment(
        self,
        **kwargs,
    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
        r"""Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system.
        This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters.
        See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        **Args:**
        - ****kwargs**: ignored by HAT itself; accepted so subclasses (e.g. AdaHAT)
          can receive extras such as `network_sparsity` through the same call site.

        **Returns:**
        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer name and values (`Tensor`) are the adjustment rate tensors.
        - **capacity** (`Tensor`): the calculated network capacity.
        """

        # initialize network capacity metric
        capacity = HATNetworkCapacityMetric().to(self.device)
        adjustment_rate_weight = {}
        adjustment_rate_bias = {}

        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist)
        for layer_name in self.backbone.weighted_layer_names:

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # placeholder for the adjustment rate to avoid the error of using it before assignment
            adjustment_rate_weight_layer = 1
            adjustment_rate_bias_layer = 1

            # broadcast the unit-wise cumulative mask to parameter-wise masks
            # (1 where a parameter links to a previously-masked unit, else 0)
            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
                layer_name=layer_name,
                aggregation_mode="min",
            )

            # in all modes below, unmasked parameters (mask == 0) keep rate 1;
            # masked parameters (mask == 1) get rate 0 / random / alpha / 1 respectively
            if self.adjustment_mode == "hat":
                adjustment_rate_weight_layer = 1 - weight_mask
                adjustment_rate_bias_layer = 1 - bias_mask

            elif self.adjustment_mode == "hat_random":
                adjustment_rate_weight_layer = torch.rand_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = torch.rand_like(bias_mask) * bias_mask + (
                    1 - bias_mask
                )

            elif self.adjustment_mode == "hat_const_alpha":
                adjustment_rate_weight_layer = self.alpha * torch.ones_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = self.alpha * torch.ones_like(
                    bias_mask
                ) * bias_mask + (1 - bias_mask)

            elif self.adjustment_mode == "hat_const_1":
                adjustment_rate_weight_layer = torch.ones_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = torch.ones_like(bias_mask) * bias_mask + (
                    1 - bias_mask
                )

            # apply the adjustment rate to the gradients
            layer.weight.grad.data *= adjustment_rate_weight_layer
            if layer.bias is not None:
                layer.bias.grad.data *= adjustment_rate_bias_layer

            # store the adjustment rate for logging
            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
            if layer.bias is not None:
                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer

            # update network capacity metric
            # NOTE(review): for bias-less layers this still feeds the placeholder
            # `adjustment_rate_bias_layer` (scalar 1 or the computed tensor) into the
            # metric — confirm `HATNetworkCapacityMetric.update()` expects that
            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)

        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()

    def compensate_task_embedding_gradients(
        self,
        batch_idx: int,
        num_batches: int,
    ) -> None:
        r"""Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Args:**
        - **batch_idx** (`int`): the current training batch index.
        - **num_batches** (`int`): the total number of training batches.
        """

        for te in self.backbone.task_embedding_t.values():
            # annealed scalar s grows from 1/s_max to s_max over the epoch's batches
            anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
                batch_idx - 1
            ) / (
                num_batches - 1
            )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)

            # numerator of the compensation ratio; clamp keeps cosh from overflowing
            num = (
                torch.cosh(
                    torch.clamp(
                        anneal_scalar * te.weight.data,
                        -self.clamp_threshold,
                        self.clamp_threshold,
                    )
                )
                + 1
            )

            den = torch.cosh(te.weight.data) + 1

            compensation = self.s_max / anneal_scalar * num / den

            # rescale embedding gradients in place before the optimizer step
            te.weight.grad.data *= compensation

    def forward(
        self,
        input: torch.Tensor,
        stage: str,
        task_id: int | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ).
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
        """
        feature, mask, activations = self.backbone(
            input,
            stage=stage,
            s_max=self.s_max if stage == "train" or stage == "validation" else None,
            batch_idx=batch_idx if stage == "train" else None,
            num_batches=num_batches if stage == "train" else None,
            test_task_id=task_id if stage == "test" else None,
        )
        logits = self.heads(feature, task_id)

        # when `if_forward_func_return_logits_only` (inherited flag) is set,
        # only logits are returned for API compatibility with generic callers
        return (
            logits
            if self.if_forward_func_return_logits_only
            else (logits, mask, activations)
        )

    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.
        - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
        """
        x, y = batch

        # zero the gradients before forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        num_batches = self.trainer.num_training_batches
        logits, mask, activations = self.forward(
            x,
            stage="train",
            batch_idx=batch_idx,
            num_batches=num_batches,
            task_id=self.task_id,
        )
        loss_cls = self.criterion(logits, y)

        # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        loss_reg, network_sparsity = self.mark_sparsity_reg(
            mask, self.cumulative_mask_for_previous_tasks
        )

        # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        loss = loss_cls + loss_reg

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
        # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).

        adjustment_rate_weight, adjustment_rate_bias, capacity = (
            self.clip_grad_by_adjustment(
                network_sparsity=network_sparsity,  # passed for compatibility with AdaHAT, which inherits this method
            )
        )
        # compensate the gradients of task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.compensate_task_embedding_gradients(
            batch_idx=batch_idx,
            num_batches=num_batches,
        )
        # update parameters with the modified gradients
        opt.step()

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "loss_reg": loss_reg,
            "acc": acc,
            "activations": activations,
            "logits": logits,
            "mask": mask,  # return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
            "input": x,  # return the input batch for Captum to use
            "target": y,  # return the target batch for Captum to use
            "adjustment_rate_weight": adjustment_rate_weight,  # return the adjustment rate for weights and biases for logging
            "adjustment_rate_bias": adjustment_rate_bias,
            "capacity": capacity,  # return the network capacity for logging
        }

    def on_train_end(self) -> None:
        r"""Store the current task's mask and update the cumulative mask after training the task."""

        # store the mask for the current task
        mask_t = self.backbone.store_mask()

        # store the batch normalization if necessary
        self.backbone.store_bn()

        # update the cumulative mask. See the first Eq. in Sec 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.cumulative_mask_for_previous_tasks = {
            layer_name: torch.max(
                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
            )
            for layer_name in self.backbone.weighted_layer_names
        }

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
        }

    def test_step(
        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y = batch
        logits, _, _ = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
        }
class HAT(clarena.cl_algorithms.base.CLAlgorithm):
 25class HAT(CLAlgorithm):
 26    r"""[HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm.
 27
 28    An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
 29    """
 30
 31    def __init__(
 32        self,
 33        backbone: HATMaskBackbone,
 34        heads: HeadsTIL,
 35        adjustment_mode: str,
 36        s_max: float,
 37        clamp_threshold: float,
 38        mask_sparsity_reg_factor: float,
 39        mask_sparsity_reg_mode: str = "original",
 40        task_embedding_init_mode: str = "N01",
 41        alpha: float | None = None,
 42        non_algorithmic_hparams: dict[str, Any] = {},
 43    ) -> None:
 44        r"""Initialize the HAT algorithm with the network.
 45
 46        **Args:**
 47        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
 48        - **heads** (`HeadsTIL`): output heads. HAT only supports TIL (Task-Incremental Learning).
 49        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
 50            1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 51            2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 52            3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 53            4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 54        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 55        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 56        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 57        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 58            1. 'original' (default): the original mask sparsity regularization in the HAT paper.
 59            2. 'cross': the cross version of mask sparsity regularization.
 60        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 61            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 62            2. 'U-11': uniform distribution $U(-1, 1)$.
 63            3. 'U01': uniform distribution $U(0, 1)$.
 64            4. 'U-10': uniform distribution $U(-1, 0)$.
 65            5. 'last': inherit the task embedding from the last task.
 66        - **alpha** (`float` | `None`): the `alpha` in the 'HAT-const-alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
 67        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 68
 69        """
 70        super().__init__(
 71            backbone=backbone,
 72            heads=heads,
 73            non_algorithmic_hparams=non_algorithmic_hparams,
 74        )
 75
 76        # save additional algorithmic hyperparameters
 77        self.save_hyperparameters(
 78            "adjustment_mode",
 79            "s_max",
 80            "clamp_threshold",
 81            "mask_sparsity_reg_factor",
 82            "mask_sparsity_reg_mode",
 83            "task_embedding_init_mode",
 84            "alpha",
 85        )
 86
 87        self.adjustment_mode: str = adjustment_mode
 88        r"""The adjustment mode for gradient clipping."""
 89        self.s_max: float = s_max
 90        r"""The hyperparameter s_max."""
 91        self.clamp_threshold: float = clamp_threshold
 92        r"""The clamp threshold for task embedding gradient compensation."""
 93        self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor
 94        r"""The mask sparsity regularization factor."""
 95        self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode
 96        r"""The mask sparsity regularization mode."""
 97        self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg(
 98            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
 99        )
100        r"""The mask sparsity regularizer."""
101        self.task_embedding_init_mode: str = task_embedding_init_mode
102        r"""The task embedding initialization mode."""
103        self.alpha: float | None = alpha
104        r"""The hyperparameter alpha for `hat_const_alpha`."""
105        # self.epsilon: float | None = None
106        # r"""HAT doesn't use epsilon for `hat_const_alpha`. It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`."""
107
108        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
109        r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """
110
111        # set manual optimization
112        self.automatic_optimization = False
113
114        HAT.sanity_check(self)
115
116    def sanity_check(self) -> None:
117        r"""Sanity check."""
118
119        # check the backbone and heads
120        if not isinstance(self.backbone, HATMaskBackbone):
121            raise ValueError("The backbone should be an instance of `HATMaskBackbone`.")
122        if not isinstance(self.heads, HeadsTIL):
123            raise ValueError("The heads should be an instance of `HeadsTIL`.")
124
125        # check marker sparsity regularization mode
126        if self.mask_sparsity_reg_mode not in ["original", "cross"]:
127            raise ValueError(
128                "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
129            )
130
131        # check task embedding initialization mode
132        if self.task_embedding_init_mode not in [
133            "N01",
134            "U01",
135            "U-10",
136            "masked",
137            "unmasked",
138        ]:
139            raise ValueError(
140                "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'."
141            )
142
143        # check adjustment mode `hat_const_alpha`
144        if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
145            raise ValueError(
146                "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
147            )
148
149    def on_train_start(self) -> None:
150        r"""Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task."""
151
152        self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode)
153
154        self.backbone.initialize_independent_bn()
155
156        # initialize the cumulative mask for the first task at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
157        if self.task_id == 1:
158            for layer_name in self.backbone.weighted_layer_names:
159                layer = self.backbone.get_layer_by_name(
160                    layer_name
161                )  # get the layer by its name
162                num_units = layer.weight.shape[0]
163
164                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
165                    num_units
166                ).to(
167                    self.device
168                )  # the cumulative mask $\mathrm{M}^{<t}$ is initialized as a zeros mask ($t = 1$). See Eq. (2) in Sec. 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
169
170                # self.neuron_first_task[layer_name] = [None] * num_units
171
    def clip_grad_by_adjustment(
        self,
        **kwargs,
    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
        r"""Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system.
        This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters.
        See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        **Args:**
        - ****kwargs**: ignored by HAT itself; accepted so that subclasses (e.g. AdaHAT) needing extra inputs such as `network_sparsity` can share the call site in `training_step()`.

        **Returns:**
        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. Only layers that actually have a bias appear here.
        - **capacity** (`Tensor`): the calculated network capacity.
        """

        # initialize network capacity metric
        capacity = HATNetworkCapacityMetric().to(self.device)
        adjustment_rate_weight = {}
        adjustment_rate_bias = {}

        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist)
        for layer_name in self.backbone.weighted_layer_names:

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # placeholder for the adjustment rate to avoid the error of using it before assignment.
            # NOTE(review): an unrecognized `adjustment_mode` falls through every branch
            # below with rate 1 (no clipping at all) — make sure the mode is validated
            # upstream (e.g. in `sanity_check()`).
            adjustment_rate_weight_layer = 1
            adjustment_rate_bias_layer = 1

            # expand the unit-wise cumulative mask to parameter-wise masks for this
            # layer's weight and bias (min-aggregated over connected units)
            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
                layer_name=layer_name,
                aggregation_mode="min",
            )

            if self.adjustment_mode == "hat":
                # masked parameters get rate 0 (frozen), unmasked get rate 1
                adjustment_rate_weight_layer = 1 - weight_mask
                adjustment_rate_bias_layer = 1 - bias_mask

            elif self.adjustment_mode == "hat_random":
                # masked parameters get a random rate in [0, 1); unmasked get rate 1
                adjustment_rate_weight_layer = torch.rand_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = torch.rand_like(bias_mask) * bias_mask + (
                    1 - bias_mask
                )

            elif self.adjustment_mode == "hat_const_alpha":
                # masked parameters get the constant rate `alpha`; unmasked get rate 1
                adjustment_rate_weight_layer = self.alpha * torch.ones_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = self.alpha * torch.ones_like(
                    bias_mask
                ) * bias_mask + (1 - bias_mask)

            elif self.adjustment_mode == "hat_const_1":
                # rate 1 everywhere, i.e. no gradient constraint (ablation baseline)
                adjustment_rate_weight_layer = torch.ones_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias_layer = torch.ones_like(bias_mask) * bias_mask + (
                    1 - bias_mask
                )

            # apply the adjustment rate to the gradients (in place, on .grad.data)
            layer.weight.grad.data *= adjustment_rate_weight_layer
            if layer.bias is not None:
                layer.bias.grad.data *= adjustment_rate_bias_layer

            # store the adjustment rate for logging
            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
            if layer.bias is not None:
                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer

            # update network capacity metric. Note the bias rate is folded in even
            # when the layer has no bias parameter (it is derived from `bias_mask`
            # regardless).
            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)

        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
254
255    def compensate_task_embedding_gradients(
256        self,
257        batch_idx: int,
258        num_batches: int,
259    ) -> None:
260        r"""Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
261
262        **Args:**
263        - **batch_idx** (`int`): the current training batch index.
264        - **num_batches** (`int`): the total number of training batches.
265        """
266
267        for te in self.backbone.task_embedding_t.values():
268            anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
269                batch_idx - 1
270            ) / (
271                num_batches - 1
272            )  # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
273
274            num = (
275                torch.cosh(
276                    torch.clamp(
277                        anneal_scalar * te.weight.data,
278                        -self.clamp_threshold,
279                        self.clamp_threshold,
280                    )
281                )
282                + 1
283            )
284
285            den = torch.cosh(te.weight.data) + 1
286
287            compensation = self.s_max / anneal_scalar * num / den
288
289            te.weight.grad.data *= compensation
290
    def forward(
        self,
        input: torch.Tensor,
        stage: str,
        task_id: int | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
    ) -> Tensor | tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): The input tensor from data.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.

        **Returns:**
        If `self.if_forward_func_return_logits_only` is `True`, only **logits** is returned; otherwise the full tuple:
        - **logits** (`Tensor`): the output logits tensor.
        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ).
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
        """
        # s_max is forwarded only for train/validation passes; batch annealing info
        # only while training; at test time the backbone instead receives the ID of
        # the task under test so it can use that task's stored mask.
        feature, mask, activations = self.backbone(
            input,
            stage=stage,
            s_max=self.s_max if stage == "train" or stage == "validation" else None,
            batch_idx=batch_idx if stage == "train" else None,
            num_batches=num_batches if stage == "train" else None,
            test_task_id=task_id if stage == "test" else None,
        )
        # route the backbone features to the head of `task_id` (TIL)
        logits = self.heads(feature, task_id)

        # some callers (e.g. attribution tooling) want a plain logits tensor; the
        # flag from the base class selects between the two return shapes
        return (
            logits
            if self.if_forward_func_return_logits_only
            else (logits, mask, activations)
        )
331
    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        Runs one full manual-optimization step: forward pass, classification +
        sparsity losses, manual backward, HAT gradient clipping, task-embedding
        gradient compensation, then the optimizer step.

        **Args:**
        - **batch** (`Any`): a batch of training data (an `(x, y)` pair).
        - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. For HAT, it includes 'mask' and 'capacity' for logging.
        """
        x, y = batch

        # zero the gradients before forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        num_batches = self.trainer.num_training_batches
        logits, mask, activations = self.forward(
            x,
            stage="train",
            batch_idx=batch_idx,
            num_batches=num_batches,
            task_id=self.task_id,
        )
        loss_cls = self.criterion(logits, y)

        # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        # NOTE(review): `mark_sparsity_reg` is the attribute name assigned in
        # `__init__()`; it looks like a typo for "mask" — keep spellings in sync if
        # it is ever renamed.
        loss_reg, network_sparsity = self.mark_sparsity_reg(
            mask, self.cumulative_mask_for_previous_tasks
        )

        # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        loss = loss_cls + loss_reg

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
        # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).

        adjustment_rate_weight, adjustment_rate_bias, capacity = (
            self.clip_grad_by_adjustment(
                network_sparsity=network_sparsity,  # ignored by HAT itself; passed for compatibility with AdaHAT, which inherits this method
            )
        )
        # compensate the gradients of task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.compensate_task_embedding_gradients(
            batch_idx=batch_idx,
            num_batches=num_batches,
        )
        # update parameters with the modified gradients
        opt.step()

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # total loss (already backpropagated manually above); returned for loggers and Lightning conventions
            "loss_cls": loss_cls,
            "loss_reg": loss_reg,
            "acc": acc,
            "activations": activations,
            "logits": logits,
            "mask": mask,  # return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
            "input": x,  # return the input batch for Captum to use
            "target": y,  # return the target batch for Captum to use
            "adjustment_rate_weight": adjustment_rate_weight,  # return the adjustment rate for weights and biases for logging
            "adjustment_rate_bias": adjustment_rate_bias,
            "capacity": capacity,  # return the network capacity for logging
        }
402
403    def on_train_end(self) -> None:
404        r"""The mask and update the cumulative mask after training the task."""
405
406        # store the mask for the current task
407        mask_t = self.backbone.store_mask()
408
409        # store the batch normalization if necessary
410        self.backbone.store_bn()
411
412        # update the cumulative mask. See the first Eq. in Sec 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
413        self.cumulative_mask_for_previous_tasks = {
414            layer_name: torch.max(
415                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
416            )
417            for layer_name in self.backbone.weighted_layer_names
418        }
419
420    def validation_step(self, batch: Any) -> dict[str, Tensor]:
421        r"""Validation step for current task `self.task_id`.
422
423        **Args:**
424        - **batch** (`Any`): a batch of validation data.
425
426        **Returns:**
427        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
428        """
429        x, y = batch
430        logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
431        loss_cls = self.criterion(logits, y)
432        acc = (logits.argmax(dim=1) == y).float().mean()
433
434        return {
435            "loss_cls": loss_cls,
436            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
437        }
438
439    def test_step(
440        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
441    ) -> dict[str, Tensor]:
442        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
443
444        **Args:**
445        - **batch** (`Any`): a batch of test data.
446        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
447
448        **Returns:**
449        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
450        """
451        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
452
453        x, y = batch
454        logits, _, _ = self.forward(
455            x,
456            stage="test",
457            task_id=test_task_id,
458        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
459        loss_cls = self.criterion(logits, y)
460        acc = (logits.argmax(dim=1) == y).float().mean()
461
462        return {
463            "loss_cls": loss_cls,
464            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
465        }

HAT (Hard Attention to the Task) algorithm.

An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.

HAT( backbone: clarena.backbones.HATMaskBackbone, heads: clarena.heads.HeadsTIL, adjustment_mode: str, s_max: float, clamp_threshold: float, mask_sparsity_reg_factor: float, mask_sparsity_reg_mode: str = 'original', task_embedding_init_mode: str = 'N01', alpha: float | None = None, non_algorithmic_hparams: dict[str, typing.Any] = {})
    def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL,
        adjustment_mode: str,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        task_embedding_init_mode: str = "N01",
        alpha: float | None = None,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""Initialize the HAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL`): output heads. HAT only supports TIL (Task-Incremental Learning).
        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
            1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
            1. 'original' (default): the original mask sparsity regularization in the HAT paper.
            2. 'cross': the cross version of mask sparsity regularization.
        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings. Must be one of the values accepted by `sanity_check()`:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U01': uniform distribution $U(0, 1)$.
            3. 'U-10': uniform distribution $U(-1, 0)$.
            4. 'masked' / 'unmasked': initialization relative to previously masked units — see `HATMaskBackbone.initialize_task_embedding()` for the exact semantics.
        - **alpha** (`float` | `None`): the `alpha` in the 'hat_const_alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. NOTE(review): the `{}` default is a mutable default argument shared across all constructions — consider a `None` sentinel.

        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "adjustment_mode",
            "s_max",
            "clamp_threshold",
            "mask_sparsity_reg_factor",
            "mask_sparsity_reg_mode",
            "task_embedding_init_mode",
            "alpha",
        )

        self.adjustment_mode: str = adjustment_mode
        r"""The adjustment mode for gradient clipping."""
        self.s_max: float = s_max
        r"""The hyperparameter s_max."""
        self.clamp_threshold: float = clamp_threshold
        r"""The clamp threshold for task embedding gradient compensation."""
        self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor
        r"""The mask sparsity regularization factor."""
        self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode
        r"""The mask sparsity regularization mode."""
        # NOTE(review): attribute name is apparently a typo for `mask_sparsity_reg`;
        # `training_step()` references it under this spelling, so keep them in sync.
        self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg(
            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
        )
        r"""The mask sparsity regularizer."""
        self.task_embedding_init_mode: str = task_embedding_init_mode
        r"""The task embedding initialization mode."""
        self.alpha: float | None = alpha
        r"""The hyperparameter alpha for `hat_const_alpha`."""
        # self.epsilon: float | None = None
        # r"""HAT doesn't use epsilon for `hat_const_alpha`. It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`."""

        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """

        # set manual optimization: HAT clips gradients and steps the optimizer
        # itself in `training_step()`
        self.automatic_optimization = False

        # explicit (non-virtual) call so subclasses overriding `sanity_check()`
        # still run HAT's own base checks
        HAT.sanity_check(self)

Initialize the HAT algorithm with the network.

Args:

  • backbone (HATMaskBackbone): must be a backbone network with the HAT mask mechanism.
  • heads (HeadsTIL): output heads. HAT only supports TIL (Task-Incremental Learning).
  • adjustment_mode (str): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
    1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
    2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the AdaHAT paper.
    3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value alpha. See "Baselines" in Sec. 4.1 in the AdaHAT paper.
    4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the AdaHAT paper.
  • s_max (float): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
  • clamp_threshold (float): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
  • mask_sparsity_reg_factor (float): hyperparameter, the regularization factor for mask sparsity.
  • mask_sparsity_reg_mode (str): the mode of mask sparsity regularization, must be one of:
    1. 'original' (default): the original mask sparsity regularization in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularization.
  • task_embedding_init_mode (str): the initialization mode for task embeddings, must be one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
  • alpha (float | None): the alpha in the 'HAT-const-alpha' mode. Applies only when adjustment_mode is 'hat_const_alpha'.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
adjustment_mode: str

The adjustment mode for gradient clipping.

s_max: float

The hyperparameter s_max.

clamp_threshold: float

The clamp threshold for task embedding gradient compensation.

mask_sparsity_reg_factor: float

The mask sparsity regularization factor.

mask_sparsity_reg_mode: str

The mask sparsity regularization mode.

The mask sparsity regularizer.

task_embedding_init_mode: str

The task embedding initialization mode.

alpha: float | None

The hyperparameter alpha for hat_const_alpha.

cumulative_mask_for_previous_tasks: dict[str, torch.Tensor]

The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ).

automatic_optimization: bool
290    @property
291    def automatic_optimization(self) -> bool:
292        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
293        return self._automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

def sanity_check(self) -> None:
def sanity_check(self) -> None:
    r"""Check that the components and hyperparameters given to HAT are consistent.

    **Raises:**
    - **ValueError**: when the backbone or heads have the wrong type, when
      `mask_sparsity_reg_mode` or `task_embedding_init_mode` is not a known mode,
      or when `adjustment_mode` is 'hat_const_alpha' but no `alpha` is given.
    """

    # HAT requires the mask mechanism in the backbone and works only with TIL heads
    if not isinstance(self.backbone, HATMaskBackbone):
        raise ValueError("The backbone should be an instance of `HATMaskBackbone`.")
    if not isinstance(self.heads, HeadsTIL):
        raise ValueError("The heads should be an instance of `HeadsTIL`.")

    # check mask sparsity regularization mode
    if self.mask_sparsity_reg_mode not in ["original", "cross"]:
        raise ValueError(
            "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
        )

    # check task embedding initialization mode. The accepted set covers all
    # documented modes ('N01', 'U-11', 'U01', 'U-10', 'last') — previously
    # 'U-11' and 'last' were wrongly rejected despite being documented — plus
    # the formerly accepted 'masked' / 'unmasked' for backward compatibility.
    if self.task_embedding_init_mode not in [
        "N01",
        "U-11",
        "U01",
        "U-10",
        "last",
        "masked",
        "unmasked",
    ]:
        raise ValueError(
            "The task_embedding_init_mode should be one of 'N01', 'U-11', 'U01', 'U-10', 'last', 'masked', 'unmasked'."
        )

    # 'hat_const_alpha' scales gradients of masked parameters by a constant alpha, which must be provided
    if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
        raise ValueError(
            "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
        )

Sanity check.

def on_train_start(self) -> None:
def on_train_start(self) -> None:
    r"""Prepare for the next task: (re-)initialize the task embedding and the
    independent batch normalization in the backbone, and create the all-zero
    cumulative mask at the beginning of the first task.
    """

    self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode)
    self.backbone.initialize_independent_bn()

    # The cumulative mask $\mathrm{M}^{<t}$ starts as all zeros for $t = 1$.
    # See Eq. (2) in Sec. 3 in the AdaHAT paper
    # (https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5)
    # in Sec. 2.6 "Promoting Low Capacity Usage" in the HAT paper
    # (http://proceedings.mlr.press/v80/serra18a). This cannot be done in
    # `__init__()` because `self.device` is not available at that time.
    if self.task_id == 1:
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)
            num_units = layer.weight.shape[0]  # one mask entry per output unit
            self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
                num_units, device=self.device
            )

Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task.

def clip_grad_by_adjustment( self, **kwargs) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], torch.Tensor]:
def clip_grad_by_adjustment(
    self,
    **kwargs,
) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
    r"""Scale parameter gradients by adjustment rates derived from the cumulative
    masks of previous tasks. See Eq. (2) in Sec. 2.3 "Network Training" in the
    [HAT paper](http://proceedings.mlr.press/v80/serra18a).

    Because the task embeddings fully cover every layer in the backbone network, no
    parameters are left out of this system — including the parameters before the
    first layer, which are handled separately in the code.

    Network capacity is measured alongside this method; it is defined as the average
    adjustment rate over all parameters. See Sec. 4.1 in the
    [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

    **Returns:**
    - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
    - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
    - **capacity** (`Tensor`): the calculated network capacity.
    """

    # capacity = average adjustment rate over all parameters (AdaHAT Sec. 4.1)
    capacity_metric = HATNetworkCapacityMetric().to(self.device)
    adjustment_rate_weight: dict[str, Tensor] = {}
    adjustment_rate_bias: dict[str, Tensor] = {}

    for layer_name in self.backbone.weighted_layer_names:
        layer = self.backbone.get_layer_by_name(layer_name)

        # parameter-wise masks aggregated ("min") from the unit-wise cumulative mask
        weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
            neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
            layer_name=layer_name,
            aggregation_mode="min",
        )

        # scalar 1 is the fallback rate (leaves gradients untouched) so an
        # unrecognized adjustment mode never uses an unassigned variable
        rate_w = 1
        rate_b = 1

        if self.adjustment_mode == "hat":
            # hard clipping: zero the gradient wherever the cumulative mask is 1
            rate_w = 1 - weight_mask
            rate_b = 1 - bias_mask
        elif self.adjustment_mode == "hat_random":
            # random rate in [0, 1) on masked parameters, 1 elsewhere
            rate_w = torch.rand_like(weight_mask) * weight_mask + (1 - weight_mask)
            rate_b = torch.rand_like(bias_mask) * bias_mask + (1 - bias_mask)
        elif self.adjustment_mode == "hat_const_alpha":
            # constant rate alpha on masked parameters, 1 elsewhere
            rate_w = self.alpha * torch.ones_like(weight_mask) * weight_mask + (
                1 - weight_mask
            )
            rate_b = self.alpha * torch.ones_like(bias_mask) * bias_mask + (
                1 - bias_mask
            )
        elif self.adjustment_mode == "hat_const_1":
            # constant rate 1 everywhere (no clipping; kept as an ablation mode)
            rate_w = torch.ones_like(weight_mask) * weight_mask + (1 - weight_mask)
            rate_b = torch.ones_like(bias_mask) * bias_mask + (1 - bias_mask)

        # scale the gradients in place
        layer.weight.grad.data *= rate_w
        if layer.bias is not None:
            layer.bias.grad.data *= rate_b

        # keep the rates for logging
        adjustment_rate_weight[layer_name] = rate_w
        if layer.bias is not None:
            adjustment_rate_bias[layer_name] = rate_b

        capacity_metric.update(rate_w, rate_b)

    return adjustment_rate_weight, adjustment_rate_bias, capacity_metric.compute()

Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.

Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.

Returns:

  • adjustment_rate_weight (dict[str, Tensor]): the adjustment rate for weights. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
  • adjustment_rate_bias (dict[str, Tensor]): the adjustment rate for biases. Keys (str) are layer name and values (Tensor) are the adjustment rate tensors.
  • capacity (Tensor): the calculated network capacity.
def compensate_task_embedding_gradients(self, batch_idx: int, num_batches: int) -> None:
def compensate_task_embedding_gradients(
    self,
    batch_idx: int,
    num_batches: int,
) -> None:
    r"""Rescale the gradients of the task embeddings to counter the vanishing effect
    of the annealed hard-attention gate. See Sec. 2.5 "Embedding Gradient
    Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

    **Args:**
    - **batch_idx** (`int`): the current training batch index.
    - **num_batches** (`int`): the total number of training batches.
    """

    # annealed scalar s; see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the
    # HAT paper (http://proceedings.mlr.press/v80/serra18a). Invariant across task
    # embeddings, so it is computed once outside the loop.
    # NOTE(review): assumes num_batches > 1, otherwise this divides by zero — confirm.
    anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
        batch_idx - 1
    ) / (num_batches - 1)

    for te in self.backbone.task_embedding_t.values():
        # clamp before cosh to avoid numerical overflow for large embeddings
        clamped = torch.clamp(
            anneal_scalar * te.weight.data,
            -self.clamp_threshold,
            self.clamp_threshold,
        )
        numerator = torch.cosh(clamped) + 1
        denominator = torch.cosh(te.weight.data) + 1

        te.weight.grad.data *= self.s_max / anneal_scalar * numerator / denominator

Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.

Args:

  • batch_idx (int): the current training batch index.
  • num_batches (int): the total number of training batches.
def forward( self, input: torch.Tensor, stage: str, task_id: int | None = None, batch_idx: int | None = None, num_batches: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
291    def forward(
292        self,
293        input: torch.Tensor,
294        stage: str,
295        task_id: int | None = None,
296        batch_idx: int | None = None,
297        num_batches: int | None = None,
298    ) -> tuple[Tensor, dict[str, Tensor]]:
299        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
300
301        **Args:**
302        - **input** (`Tensor`): The input tensor from data.
303        - **stage** (`str`): the stage of the forward pass; one of:
304            1. 'train': training stage.
305            2. 'validation': validation stage.
306            3. 'test': testing stage.
307        - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
308        - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`.
309        - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`.
310
311        **Returns:**
312        - **logits** (`Tensor`): the output logits tensor.
313        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ).
314        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class.
315        """
316        feature, mask, activations = self.backbone(
317            input,
318            stage=stage,
319            s_max=self.s_max if stage == "train" or stage == "validation" else None,
320            batch_idx=batch_idx if stage == "train" else None,
321            num_batches=num_batches if stage == "train" else None,
322            test_task_id=task_id if stage == "test" else None,
323        )
324        logits = self.heads(feature, task_id)
325
326        return (
327            logits
328            if self.if_forward_func_return_logits_only
329            else (logits, mask, activations)
330        )

The forward pass for data from task task_id. Note that it is nothing to do with forward() method in nn.Module.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): the stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • task_id (int| None): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task self.task_id. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL.
  • batch_idx (int | None): the current batch index. Applies only to training stage. For other stages, it is default None.
  • num_batches (int | None): the total number of batches. Applies only to training stage. For other stages, it is default None.

Returns:

  • logits (Tensor): the output logits tensor.
  • mask (dict[str, Tensor]): the mask for the current task. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has size (number of units, ).
  • activations (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this forward() method of HAT class.
def training_step(self, batch: Any, batch_idx: int) -> dict[str, torch.Tensor]:
def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
    r"""Training step for current task `self.task_id`.

    **Args:**
    - **batch** (`Any`): a batch of training data.
    - **batch_idx** (`int`): the index of the batch. Used for calculating the annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
    """
    x, y = batch

    # manual optimization mode: clear gradients ourselves before the forward pass
    optimizer = self.optimizers()
    optimizer.zero_grad()

    # classification loss
    num_batches = self.trainer.num_training_batches
    logits, mask, activations = self.forward(
        x,
        stage="train",
        batch_idx=batch_idx,
        num_batches=num_batches,
        task_id=self.task_id,
    )
    loss_cls = self.criterion(logits, y)

    # mask sparsity regularization loss. See Sec. 2.6 "Promoting Low Capacity
    # Usage" in the HAT paper (http://proceedings.mlr.press/v80/serra18a).
    # NOTE(review): the attribute is spelled `mark_sparsity_reg` in this codebase
    # (probable historical typo for "mask"); the spelling is kept to match the
    # attribute set elsewhere in the class — confirm before renaming.
    loss_reg, network_sparsity = self.mark_sparsity_reg(
        mask, self.cumulative_mask_for_previous_tasks
    )

    # total loss. See Eq. (4) in Sec. 2.6 in the HAT paper
    loss = loss_cls + loss_reg

    # backward pass (manual)
    self.manual_backward(loss)

    # HAT hard-clips gradients using the cumulative masks (Eq. (2), Sec. 2.3 in the
    # HAT paper); network capacity (average adjustment rate over all parameters,
    # AdaHAT Sec. 4.1) is computed along the way
    adjustment_rate_weight, adjustment_rate_bias, capacity = (
        self.clip_grad_by_adjustment(
            network_sparsity=network_sparsity,  # passed for compatibility with AdaHAT, which inherits this method
        )
    )

    # compensate the gradients of the task embedding (Sec. 2.5 in the HAT paper)
    self.compensate_task_embedding_gradients(
        batch_idx=batch_idx,
        num_batches=num_batches,
    )

    # apply the modified gradients
    optimizer.step()

    # batch accuracy
    acc = (logits.argmax(dim=1) == y).float().mean()

    return {
        "loss": loss,  # returning 'loss' is essential for the training step
        "loss_cls": loss_cls,
        "loss_reg": loss_reg,
        "acc": acc,
        "activations": activations,
        "logits": logits,
        "mask": mask,  # other metrics are handled by the lightning loggers callback at `on_train_batch_end()`
        "input": x,  # the input batch, for Captum to use
        "target": y,  # the target batch, for Captum to use
        "adjustment_rate_weight": adjustment_rate_weight,  # adjustment rates for weights and biases, for logging
        "adjustment_rate_bias": adjustment_rate_bias,
        "capacity": capacity,  # network capacity, for logging
    }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.
  • batch_idx (int): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the HAT paper.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing loss and other metrics from this training step. Keys (str) are metric names, and values (Tensor) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
def on_train_end(self) -> None:
def on_train_end(self) -> None:
    r"""Store the mask (and batch norm) of the finished task and merge the mask into the cumulative mask."""

    # persist the current task's mask
    mask_t = self.backbone.store_mask()

    # persist the batch normalization if necessary
    self.backbone.store_bn()

    # cumulative mask is the element-wise maximum of the previous cumulative mask
    # and the new task's mask. See the first Eq. in Sec. 2.3 "Network Training"
    # in the HAT paper (http://proceedings.mlr.press/v80/serra18a)
    merged = {}
    for layer_name in self.backbone.weighted_layer_names:
        merged[layer_name] = torch.max(
            self.cumulative_mask_for_previous_tasks[layer_name],
            mask_t[layer_name],
        )
    self.cumulative_mask_for_previous_tasks = merged

Store the mask and update the cumulative mask after training the task.

def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
def validation_step(self, batch: Any) -> dict[str, Tensor]:
    r"""Validation step for current task `self.task_id`.

    **Args:**
    - **batch** (`Any`): a batch of validation data.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this validation step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
    """
    x, y = batch

    # validation always uses the current task's head and mask
    logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)

    loss_cls = self.criterion(logits, y)
    accuracy = (logits.argmax(dim=1) == y).float().mean()

    # metrics are handled by the lightning loggers callback at `on_validation_batch_end()`
    return {"loss_cls": loss_cls, "acc": accuracy}

Validation step for current task self.task_id.

Args:

  • batch (Any): a batch of validation data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this validation step. Keys (str) are the metrics names, and values (Tensor) are the metrics.
def test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
439    def test_step(
440        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
441    ) -> dict[str, Tensor]:
442        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
443
444        **Args:**
445        - **batch** (`Any`): a batch of test data.
446        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
447
448        **Returns:**
449        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
450        """
451        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
452
453        x, y = batch
454        logits, _, _ = self.forward(
455            x,
456            stage="test",
457            task_id=test_task_id,
458        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
459        loss_cls = self.criterion(logits, y)
460        acc = (logits.argmax(dim=1) == y).float().mean()
461
462        return {
463            "loss_cls": loss_cls,
464            "acc": acc,  # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
465        }

Test step for current task self.task_id, which tests for all seen tasks indexed by dataloader_idx.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this test step. Keys (str) are the metrics names, and values (Tensor) are the metrics.