clarena.cl_algorithms.wsn

The submodule in cl_algorithms for WSN (Winning Subnetworks) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.
  3"""
  4
  5__all__ = ["WSN"]
  6
  7import logging
  8from typing import Any
  9
 10import torch
 11from torch import Tensor
 12from torch.utils.data import DataLoader
 13
 14from clarena.backbones import WSNMaskBackbone
 15from clarena.cl_algorithms import CLAlgorithm
 16from clarena.heads import HeadDIL, HeadsTIL
 17
 18# always get logger for built-in logging in each module
 19pylogger = logging.getLogger(__name__)
 20
 21
 22class WSN(CLAlgorithm):
 23    r"""[WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.
 24
 25    An architecture-based continual learning approach that trains learnable parameter-wise scores and selects the most scored c% of network parameters per task.
 26    """
 27
 28    def __init__(
 29        self,
 30        backbone: WSNMaskBackbone,
 31        heads: HeadsTIL | HeadDIL,
 32        mask_percentage: float,
 33        parameter_score_init_mode: str = "default",
 34        non_algorithmic_hparams: dict[str, Any] = {},
 35        **kwargs,
 36    ) -> None:
 37        r"""Initialize the WSN algorithm with the network.
 38
 39        **Args:**
 40        - **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism.
 41        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. WSN supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
 42        - **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
 43        - **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of:
 44            1. 'default': the default initialization in the original WSN code.
 45            2. 'N01': standard normal distribution $N(0, 1)$.
 46            3. 'U01': uniform distribution $U(0, 1)$.
 47        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 48        - **kwargs**: Reserved for multiple inheritance.
 49
 50        """
 51        super().__init__(
 52            backbone=backbone,
 53            heads=heads,
 54            non_algorithmic_hparams=non_algorithmic_hparams,
 55            **kwargs,
 56        )
 57
 58        self.mask_percentage: float = mask_percentage
 59        r"""The percentage of parameters to be used for each task."""
 60        self.parameter_score_init_mode: str = parameter_score_init_mode
 61        r"""The parameter score initialization mode."""
 62
 63        # save additional algorithmic hyperparameters
 64        self.save_hyperparameters(
 65            "mask_percentage",
 66            "parameter_score_init_mode",
 67        )
 68
 69        self.weight_masks: dict[int, dict[str, Tensor]] = {}
 70        r"""The binary weight mask of each previous task percentile-gated from the weight score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as weight."""
 71        self.bias_masks: dict[int, dict[str, Tensor]] = {}
 72        r"""The binary bias mask of each previous task percentile-gated from the bias score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""
 73
 74        self.cumulative_weight_mask_for_previous_tasks: dict[str, Tensor] = {}
 75        r"""The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as weight."""
 76        self.cumulative_bias_mask_for_previous_tasks: dict[str, dict[str, Tensor]] = {}
 77        r"""The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""
 78
 79        # set manual optimization
 80        self.automatic_optimization = False
 81
 82        WSN.sanity_check(self)
 83
 84    def sanity_check(self) -> None:
 85        r"""Sanity check."""
 86
 87        # check the backbone and heads
 88        if not isinstance(self.backbone, WSNMaskBackbone):
 89            raise ValueError("The backbone should be an instance of WSNMaskBackbone.")
 90        if not isinstance(self.heads, HeadsTIL):
 91            raise ValueError("The heads should be an instance of `HeadsTIL`.")
 92
 93        # check the mask percentage
 94        if not (0 < self.mask_percentage <= 1):
 95            raise ValueError(
 96                f"Mask percentage should be in (0, 1], but got {self.mask_percentage}."
 97            )
 98
 99    def on_train_start(self) -> None:
100        r"""Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task."""
101
102        self.backbone.initialize_parameter_score(mode=self.parameter_score_init_mode)
103
104        # initialize the cumulative mask at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
105        if self.task_id == 1:
106            for layer_name in self.backbone.weighted_layer_names:
107                layer = self.backbone.get_layer_by_name(
108                    layer_name
109                )  # get the layer by its name
110
111                self.cumulative_weight_mask_for_previous_tasks[layer_name] = (
112                    torch.zeros_like(layer.weight).to(self.device)
113                )
114                if layer.bias is not None:
115                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = (
116                        torch.zeros_like(layer.bias).to(self.device)
117                    )
118                else:
119                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
120                # the cumulative mask $\mathrm{M}_{t-1}$ is initialized as zeros mask ($t = 1$)
121
122    def clip_grad_by_mask(
123        self,
124    ) -> None:
125        r"""Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)."""
126
127        for layer_name in self.backbone.weighted_layer_names:
128            layer = self.backbone.get_layer_by_name(layer_name)
129
130            layer.weight.grad.data *= (
131                1 - self.cumulative_weight_mask_for_previous_tasks[layer_name]
132            )
133            if layer.bias is not None:
134                layer.bias.grad.data *= (
135                    1 - self.cumulative_bias_mask_for_previous_tasks[layer_name]
136                )
137
138    def forward(
139        self,
140        input: torch.Tensor,
141        stage: str,
142        task_id: int | None = None,
143    ) -> tuple[Tensor, dict[str, Tensor]]:
144        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
145
146        **Args:**
147        - **input** (`Tensor`): the input tensor from data.
148        - **stage** (`str`): the stage of the forward pass, should be one of:
149            1. 'train': training stage.
150            2. 'validation': validation stage.
151            3. 'test': testing stage.
152        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).
153
154        **Returns:**
155        - **logits** (`Tensor`): the output logits tensor.
156        - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, input features) as weight.
157        - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, ) as bias. If the layer doesn't have bias, it is `None`.
158        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
159        """
160        feature, weight_mask, bias_mask, activations = self.backbone(
161            input,
162            stage=stage,
163            mask_percentage=self.mask_percentage,
164            test_mask=(
165                (self.weight_masks[task_id], self.bias_masks[task_id])
166                if stage == "test"
167                else None
168            ),
169        )
170        logits = self.heads(feature, task_id)
171
172        return (
173            logits
174            if self.if_forward_func_return_logits_only
175            else (logits, weight_mask, bias_mask, activations)
176        )
177
178    def training_step(self, batch: Any) -> dict[str, Tensor]:
179        r"""Training step for current task `self.task_id`.
180
181        **Args:**
182        - **batch** (`Any`): a batch of training data.
183
184        **Returns:**
185        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
186        """
187        x, y = batch
188
189        # zero the gradients before forward pass in manual optimization mode
190        opt = self.optimizers()
191        opt.zero_grad()
192
193        # classification loss
194        logits, weight_mask, bias_mask, activations = self.forward(
195            x, stage="train", task_id=self.task_id
196        )
197        loss_cls = self.criterion(logits, y)
198
199        # total loss
200        loss = loss_cls
201
202        # backward step (manually)
203        self.manual_backward(loss)  # calculate the gradients
204        # WSN hard-clips gradients using the cumulative masks. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
205        self.clip_grad_by_mask()
206
207        # update parameters with the modified gradients
208        opt.step()
209
210        # predicted labels
211        preds = logits.argmax(dim=1)
212
213        # accuracy of the batch
214        acc = (preds == y).float().mean()
215
216        return {
217            "preds": preds,
218            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
219            "loss_cls": loss_cls,
220            "acc": acc,
221            "activations": activations,
222            "weight_mask": weight_mask,  # return other metrics for Lightning loggers callback to handle at `on_train_batch_end()`
223            "bias_mask": bias_mask,
224        }
225
226    def on_train_end(self) -> None:
227        r"""Store the weight and bias masks and update the cumulative masks after training the task."""
228
229        # get the weight and bias mask for the current task
230        weight_mask_t = {}
231        bias_mask_t = {}
232        for layer_name in self.backbone.weighted_layer_names:
233            layer = self.backbone.get_layer_by_name(layer_name)
234
235            weight_mask_t[layer_name] = self.backbone.gate_fn.apply(
236                self.backbone.weight_score_t[layer_name].weight, self.mask_percentage
237            )
238            if layer.bias is not None:
239                bias_mask_t[layer_name] = self.backbone.gate_fn.apply(
240                    self.backbone.bias_score_t[layer_name].weight.squeeze(
241                        0
242                    ),  # from (1, output_dim) to (output_dim, )
243                    self.mask_percentage,
244                )
245            else:
246                bias_mask_t[layer_name] = None
247
248        # store the weight and bias mask for the current task
249        self.weight_masks[self.task_id] = weight_mask_t
250        self.bias_masks[self.task_id] = bias_mask_t
251
252        # update the cumulative mask
253        for layer_name in self.backbone.weighted_layer_names:
254            layer = self.backbone.get_layer_by_name(layer_name)
255
256            self.cumulative_weight_mask_for_previous_tasks[layer_name] = torch.max(
257                self.cumulative_weight_mask_for_previous_tasks[layer_name],
258                weight_mask_t[layer_name],
259            )
260            if layer.bias is not None:
261                print(
262                    self.cumulative_bias_mask_for_previous_tasks[layer_name].shape,
263                    bias_mask_t[layer_name].shape,
264                )
265                self.cumulative_bias_mask_for_previous_tasks[layer_name] = torch.max(
266                    self.cumulative_bias_mask_for_previous_tasks[layer_name],
267                    bias_mask_t[layer_name],
268                )
269            else:
270                self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
271
272        print(self.cumulative_bias_mask_for_previous_tasks)
273
274    def validation_step(self, batch: Any) -> dict[str, Tensor]:
275        r"""Validation step for current task `self.task_id`.
276
277        **Args:**
278        - **batch** (`Any`): a batch of validation data.
279
280        **Returns:**
281        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
282        """
283        x, y = batch
284        logits, _, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
285        loss_cls = self.criterion(logits, y)
286        preds = logits.argmax(dim=1)
287        acc = (preds == y).float().mean()
288
289        return {
290            "loss_cls": loss_cls,
291            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_validation_batch_end()`
292            "preds": preds,
293        }
294
295    def test_step(
296        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
297    ) -> dict[str, Tensor]:
298        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
299
300        **Args:**
301        - **batch** (`Any`): a batch of test data.
302        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
303
304        **Returns:**
305        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
306        """
307        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
308
309        x, y = batch
310        logits, _, _, _ = self.forward(
311            x,
312            stage="test",
313            task_id=test_task_id,
314        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
315        loss_cls = self.criterion(logits, y)
316        preds = logits.argmax(dim=1)
317        acc = (preds == y).float().mean()
318
319        return {
320            "loss_cls": loss_cls,
321            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_test_batch_end()`
322            "preds": preds,
323        }
class WSN(CLAlgorithm):
    r"""[WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.

    An architecture-based continual learning approach that trains learnable parameter-wise scores and selects the most scored c% of network parameters per task.
    """

    def __init__(
        self,
        backbone: WSNMaskBackbone,
        heads: HeadsTIL | HeadDIL,
        mask_percentage: float,
        parameter_score_init_mode: str = "default",
        non_algorithmic_hparams: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        r"""Initialize the WSN algorithm with the network.

        **Args:**
        - **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. WSN supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
        - **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
        - **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of:
            1. 'default': the default initialization in the original WSN code.
            2. 'N01': standard normal distribution $N(0, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. If `None`, an empty dict is used.
        - **kwargs**: Reserved for multiple inheritance.
        """
        # avoid the mutable-default-argument pitfall: a literal `{}` default
        # would be shared across every WSN instance
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}

        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        self.mask_percentage: float = mask_percentage
        r"""The percentage of parameters to be used for each task."""
        self.parameter_score_init_mode: str = parameter_score_init_mode
        r"""The parameter score initialization mode."""

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "mask_percentage",
            "parameter_score_init_mode",
        )

        self.weight_masks: dict[int, dict[str, Tensor]] = {}
        r"""The binary weight mask of each previous task percentile-gated from the weight score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as weight."""
        self.bias_masks: dict[int, dict[str, Tensor | None]] = {}
        r"""The binary bias mask of each previous task percentile-gated from the bias score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""

        self.cumulative_weight_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as weight."""
        self.cumulative_bias_mask_for_previous_tasks: dict[str, Tensor | None] = {}
        r"""The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""

        # set manual optimization: WSN must clip gradients with the cumulative
        # masks between backward and optimizer step
        self.automatic_optimization = False

        # call WSN's own check explicitly so a subclass overriding
        # `sanity_check()` still gets these checks during `__init__()`
        WSN.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Check that the backbone, heads and mask percentage are valid for WSN.

        **Raises:**
        - **ValueError**: if the backbone is not a `WSNMaskBackbone`, the heads are neither `HeadsTIL` nor `HeadDIL`, or the mask percentage is outside (0, 1].
        """

        # check the backbone and heads
        if not isinstance(self.backbone, WSNMaskBackbone):
            raise ValueError("The backbone should be an instance of WSNMaskBackbone.")
        # WSN supports both TIL and DIL heads, matching the `__init__()` signature
        if not isinstance(self.heads, (HeadsTIL, HeadDIL)):
            raise ValueError(
                "The heads should be an instance of `HeadsTIL` or `HeadDIL`."
            )

        # check the mask percentage
        if not (0 < self.mask_percentage <= 1):
            raise ValueError(
                f"Mask percentage should be in (0, 1], but got {self.mask_percentage}."
            )

    def on_train_start(self) -> None:
        r"""Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task."""

        self.backbone.initialize_parameter_score(mode=self.parameter_score_init_mode)

        # initialize the cumulative mask at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name

                self.cumulative_weight_mask_for_previous_tasks[layer_name] = (
                    torch.zeros_like(layer.weight).to(self.device)
                )
                if layer.bias is not None:
                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = (
                        torch.zeros_like(layer.bias).to(self.device)
                    )
                else:
                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
                # the cumulative mask $\mathrm{M}_{t-1}$ is initialized as zeros mask ($t = 1$)

    def clip_grad_by_mask(
        self,
    ) -> None:
        r"""Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)."""

        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            layer.weight.grad.data *= (
                1 - self.cumulative_weight_mask_for_previous_tasks[layer_name]
            )
            if layer.bias is not None:
                layer.bias.grad.data *= (
                    1 - self.cumulative_bias_mask_for_previous_tasks[layer_name]
                )

    def forward(
        self,
        input: torch.Tensor,
        stage: str,
        task_id: int | None = None,
    ) -> Tensor | tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **stage** (`str`): the stage of the forward pass, should be one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor. If `self.if_forward_func_return_logits_only` is set, the logits are returned alone; otherwise the following are returned with them.
        - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, input features) as weight.
        - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, ) as bias. If the layer doesn't have bias, it is `None`.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
        """
        feature, weight_mask, bias_mask, activations = self.backbone(
            input,
            stage=stage,
            mask_percentage=self.mask_percentage,
            test_mask=(
                (self.weight_masks[task_id], self.bias_masks[task_id])
                if stage == "test"
                else None
            ),
        )
        logits = self.heads(feature, task_id)

        return (
            logits
            if self.if_forward_func_return_logits_only
            else (logits, weight_mask, bias_mask, activations)
        )

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
        """
        x, y = batch

        # zero the gradients before forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        logits, weight_mask, bias_mask, activations = self.forward(
            x, stage="train", task_id=self.task_id
        )
        loss_cls = self.criterion(logits, y)

        # total loss
        loss = loss_cls

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # WSN hard-clips gradients using the cumulative masks. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
        self.clip_grad_by_mask()

        # update parameters with the modified gradients
        opt.step()

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,
            "activations": activations,
            "weight_mask": weight_mask,  # return other metrics for Lightning loggers callback to handle at `on_train_batch_end()`
            "bias_mask": bias_mask,
        }

    def on_train_end(self) -> None:
        r"""Store the weight and bias masks and update the cumulative masks after training the task."""

        # get the weight and bias mask for the current task
        weight_mask_t = {}
        bias_mask_t = {}
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            weight_mask_t[layer_name] = self.backbone.gate_fn.apply(
                self.backbone.weight_score_t[layer_name].weight, self.mask_percentage
            )
            if layer.bias is not None:
                bias_mask_t[layer_name] = self.backbone.gate_fn.apply(
                    self.backbone.bias_score_t[layer_name].weight.squeeze(
                        0
                    ),  # from (1, output_dim) to (output_dim, )
                    self.mask_percentage,
                )
            else:
                bias_mask_t[layer_name] = None

        # store the weight and bias mask for the current task
        self.weight_masks[self.task_id] = weight_mask_t
        self.bias_masks[self.task_id] = bias_mask_t

        # update the cumulative mask by element-wise maximum (union of binary masks)
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            self.cumulative_weight_mask_for_previous_tasks[layer_name] = torch.max(
                self.cumulative_weight_mask_for_previous_tasks[layer_name],
                weight_mask_t[layer_name],
            )
            if layer.bias is not None:
                self.cumulative_bias_mask_for_previous_tasks[layer_name] = torch.max(
                    self.cumulative_bias_mask_for_previous_tasks[layer_name],
                    bias_mask_t[layer_name],
                )
            else:
                self.cumulative_bias_mask_for_previous_tasks[layer_name] = None

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_validation_batch_end()`
            "preds": preds,
        }

    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **batch_idx** (`int`): the index of the batch within the dataloader.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y = batch
        logits, _, _, _ = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_test_batch_end()`
            "preds": preds,
        }

WSN (Winning Subnetworks) algorithm.

An architecture-based continual learning approach that trains learnable parameter-wise scores and selects the most scored c% of network parameters per task.

WSN( backbone: clarena.backbones.WSNMaskBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadDIL, mask_percentage: float, parameter_score_init_mode: str = 'default', non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
29    def __init__(
30        self,
31        backbone: WSNMaskBackbone,
32        heads: HeadsTIL | HeadDIL,
33        mask_percentage: float,
34        parameter_score_init_mode: str = "default",
35        non_algorithmic_hparams: dict[str, Any] = {},
36        **kwargs,
37    ) -> None:
38        r"""Initialize the WSN algorithm with the network.
39
40        **Args:**
41        - **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism.
42        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. WSN supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
43        - **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
44        - **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of:
45            1. 'default': the default initialization in the original WSN code.
46            2. 'N01': standard normal distribution $N(0, 1)$.
47            3. 'U01': uniform distribution $U(0, 1)$.
48        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
49        - **kwargs**: Reserved for multiple inheritance.
50
51        """
52        super().__init__(
53            backbone=backbone,
54            heads=heads,
55            non_algorithmic_hparams=non_algorithmic_hparams,
56            **kwargs,
57        )
58
59        self.mask_percentage: float = mask_percentage
60        r"""The percentage of parameters to be used for each task."""
61        self.parameter_score_init_mode: str = parameter_score_init_mode
62        r"""The parameter score initialization mode."""
63
64        # save additional algorithmic hyperparameters
65        self.save_hyperparameters(
66            "mask_percentage",
67            "parameter_score_init_mode",
68        )
69
70        self.weight_masks: dict[int, dict[str, Tensor]] = {}
71        r"""The binary weight mask of each previous task percentile-gated from the weight score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as weight."""
72        self.bias_masks: dict[int, dict[str, Tensor]] = {}
73        r"""The binary bias mask of each previous task percentile-gated from the bias score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""
74
75        self.cumulative_weight_mask_for_previous_tasks: dict[str, Tensor] = {}
76        r"""The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as weight."""
77        self.cumulative_bias_mask_for_previous_tasks: dict[str, dict[str, Tensor]] = {}
78        r"""The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""
79
80        # set manual optimization
81        self.automatic_optimization = False
82
83        WSN.sanity_check(self)

Initialize the WSN algorithm with the network.

Args:

  • backbone (WSNMaskBackbone): must be a backbone network with the WSN mask mechanism.
  • heads (HeadsTIL | HeadDIL): output heads. WSN supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
  • mask_percentage (float): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the WSN paper.
  • parameter_score_init_mode (str): the initialization mode for parameter scores, must be one of:
    1. 'default': the default initialization in the original WSN code.
    2. 'N01': standard normal distribution $N(0, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
mask_percentage: float

The percentage of parameters to be used for each task.

parameter_score_init_mode: str

The parameter score initialization mode.

weight_masks: dict[int, dict[str, torch.Tensor]]

The binary weight mask of each previous task percentile-gated from the weight score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as weight.

bias_masks: dict[int, dict[str, torch.Tensor]]

The binary bias mask of each previous task percentile-gated from the bias score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is None.

cumulative_weight_mask_for_previous_tasks: dict[str, torch.Tensor]

The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as weight.

cumulative_bias_mask_for_previous_tasks: dict[str, torch.Tensor | None]

The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is None.

automatic_optimization: bool
290    @property
291    def automatic_optimization(self) -> bool:
292        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
293        return self._automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

def sanity_check(self) -> None:
85    def sanity_check(self) -> None:
86        r"""Sanity check."""
87
88        # check the backbone and heads
89        if not isinstance(self.backbone, WSNMaskBackbone):
90            raise ValueError("The backbone should be an instance of WSNMaskBackbone.")
91        if not isinstance(self.heads, HeadsTIL):
92            raise ValueError("The heads should be an instance of `HeadsTIL`.")
93
94        # check the mask percentage
95        if not (0 < self.mask_percentage <= 1):
96            raise ValueError(
97                f"Mask percentage should be in (0, 1], but got {self.mask_percentage}."
98            )

Sanity check.

def on_train_start(self) -> None:
100    def on_train_start(self) -> None:
101        r"""Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task."""
102
103        self.backbone.initialize_parameter_score(mode=self.parameter_score_init_mode)
104
105        # initialize the cumulative mask at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
106        if self.task_id == 1:
107            for layer_name in self.backbone.weighted_layer_names:
108                layer = self.backbone.get_layer_by_name(
109                    layer_name
110                )  # get the layer by its name
111
112                self.cumulative_weight_mask_for_previous_tasks[layer_name] = (
113                    torch.zeros_like(layer.weight).to(self.device)
114                )
115                if layer.bias is not None:
116                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = (
117                        torch.zeros_like(layer.bias).to(self.device)
118                    )
119                else:
120                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
121                # the cumulative mask $\mathrm{M}_{t-1}$ is initialized as zeros mask ($t = 1$)

Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task.

def clip_grad_by_mask(self) -> None:
123    def clip_grad_by_mask(
124        self,
125    ) -> None:
126        r"""Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)."""
127
128        for layer_name in self.backbone.weighted_layer_names:
129            layer = self.backbone.get_layer_by_name(layer_name)
130
131            layer.weight.grad.data *= (
132                1 - self.cumulative_weight_mask_for_previous_tasks[layer_name]
133            )
134            if layer.bias is not None:
135                layer.bias.grad.data *= (
136                    1 - self.cumulative_bias_mask_for_previous_tasks[layer_name]
137                )

Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the WSN paper.

def forward( self, input: torch.Tensor, stage: str, task_id: int | None = None) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
139    def forward(
140        self,
141        input: torch.Tensor,
142        stage: str,
143        task_id: int | None = None,
144    ) -> tuple[Tensor, dict[str, Tensor]]:
145        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
146
147        **Args:**
148        - **input** (`Tensor`): the input tensor from data.
149        - **stage** (`str`): the stage of the forward pass, should be one of:
150            1. 'train': training stage.
151            2. 'validation': validation stage.
152            3. 'test': testing stage.
153        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).
154
155        **Returns:**
156        - **logits** (`Tensor`): the output logits tensor.
157        - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, input features) as weight.
158        - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has same (output features, ) as bias. If the layer doesn't have bias, it is `None`.
159        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
160        """
161        feature, weight_mask, bias_mask, activations = self.backbone(
162            input,
163            stage=stage,
164            mask_percentage=self.mask_percentage,
165            test_mask=(
166                (self.weight_masks[task_id], self.bias_masks[task_id])
167                if stage == "test"
168                else None
169            ),
170        )
171        logits = self.heads(feature, task_id)
172
173        return (
174            logits
175            if self.if_forward_func_return_logits_only
176            else (logits, weight_mask, bias_mask, activations)
177        )

The forward pass for data from task task_id. Note that it is nothing to do with forward() method in nn.Module.

Args:

  • input (Tensor): the input tensor from data.
  • stage (str): the stage of the forward pass, should be one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • task_id (int | None): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task self.task_id. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).

Returns:

  • logits (Tensor): the output logits tensor.
  • weight_mask (dict[str, Tensor]): the weight mask for the current task. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has same (output features, input features) as weight.
  • bias_mask (dict[str, Tensor]): the bias mask for the current task. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has same (output features, ) as bias. If the layer doesn't have bias, it is None.
  • activations (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
179    def training_step(self, batch: Any) -> dict[str, Tensor]:
180        r"""Training step for current task `self.task_id`.
181
182        **Args:**
183        - **batch** (`Any`): a batch of training data.
184
185        **Returns:**
186        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
187        """
188        x, y = batch
189
190        # zero the gradients before forward pass in manual optimization mode
191        opt = self.optimizers()
192        opt.zero_grad()
193
194        # classification loss
195        logits, weight_mask, bias_mask, activations = self.forward(
196            x, stage="train", task_id=self.task_id
197        )
198        loss_cls = self.criterion(logits, y)
199
200        # total loss
201        loss = loss_cls
202
203        # backward step (manually)
204        self.manual_backward(loss)  # calculate the gradients
205        # WSN hard-clips gradients using the cumulative masks. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
206        self.clip_grad_by_mask()
207
208        # update parameters with the modified gradients
209        opt.step()
210
211        # predicted labels
212        preds = logits.argmax(dim=1)
213
214        # accuracy of the batch
215        acc = (preds == y).float().mean()
216
217        return {
218            "preds": preds,
219            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
220            "loss_cls": loss_cls,
221            "acc": acc,
222            "activations": activations,
223            "weight_mask": weight_mask,  # return other metrics for Lightning loggers callback to handle at `on_train_batch_end()`
224            "bias_mask": bias_mask,
225        }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing loss and other metrics from this training step. Keys (str) are the metrics names, and values (Tensor) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
def on_train_end(self) -> None:
227    def on_train_end(self) -> None:
228        r"""Store the weight and bias masks and update the cumulative masks after training the task."""
229
230        # get the weight and bias mask for the current task
231        weight_mask_t = {}
232        bias_mask_t = {}
233        for layer_name in self.backbone.weighted_layer_names:
234            layer = self.backbone.get_layer_by_name(layer_name)
235
236            weight_mask_t[layer_name] = self.backbone.gate_fn.apply(
237                self.backbone.weight_score_t[layer_name].weight, self.mask_percentage
238            )
239            if layer.bias is not None:
240                bias_mask_t[layer_name] = self.backbone.gate_fn.apply(
241                    self.backbone.bias_score_t[layer_name].weight.squeeze(
242                        0
243                    ),  # from (1, output_dim) to (output_dim, )
244                    self.mask_percentage,
245                )
246            else:
247                bias_mask_t[layer_name] = None
248
249        # store the weight and bias mask for the current task
250        self.weight_masks[self.task_id] = weight_mask_t
251        self.bias_masks[self.task_id] = bias_mask_t
252
253        # update the cumulative mask
254        for layer_name in self.backbone.weighted_layer_names:
255            layer = self.backbone.get_layer_by_name(layer_name)
256
257            self.cumulative_weight_mask_for_previous_tasks[layer_name] = torch.max(
258                self.cumulative_weight_mask_for_previous_tasks[layer_name],
259                weight_mask_t[layer_name],
260            )
261            if layer.bias is not None:
262                print(
263                    self.cumulative_bias_mask_for_previous_tasks[layer_name].shape,
264                    bias_mask_t[layer_name].shape,
265                )
266                self.cumulative_bias_mask_for_previous_tasks[layer_name] = torch.max(
267                    self.cumulative_bias_mask_for_previous_tasks[layer_name],
268                    bias_mask_t[layer_name],
269                )
270            else:
271                self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
272
273        print(self.cumulative_bias_mask_for_previous_tasks)

Store the weight and bias masks and update the cumulative masks after training the task.

def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
275    def validation_step(self, batch: Any) -> dict[str, Tensor]:
276        r"""Validation step for current task `self.task_id`.
277
278        **Args:**
279        - **batch** (`Any`): a batch of validation data.
280
281        **Returns:**
282        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
283        """
284        x, y = batch
285        logits, _, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
286        loss_cls = self.criterion(logits, y)
287        preds = logits.argmax(dim=1)
288        acc = (preds == y).float().mean()
289
290        return {
291            "loss_cls": loss_cls,
292            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_validation_batch_end()`
293            "preds": preds,
294        }

Validation step for current task self.task_id.

Args:

  • batch (Any): a batch of validation data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this validation step. Keys (str) are the metrics names, and values (Tensor) are the metrics.
def test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
296    def test_step(
297        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
298    ) -> dict[str, Tensor]:
299        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
300
301        **Args:**
302        - **batch** (`Any`): a batch of test data.
303        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
304
305        **Returns:**
306        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
307        """
308        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
309
310        x, y = batch
311        logits, _, _, _ = self.forward(
312            x,
313            stage="test",
314            task_id=test_task_id,
315        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
316        loss_cls = self.criterion(logits, y)
317        preds = logits.argmax(dim=1)
318        acc = (preds == y).float().mean()
319
320        return {
321            "loss_cls": loss_cls,
322            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_test_batch_end()`
323            "preds": preds,
324        }

Test step for current task self.task_id, which tests for all seen tasks indexed by dataloader_idx.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this test step. Keys (str) are the metrics names, and values (Tensor) are the metrics.