clarena.cl_algorithms.wsn

The submodule in cl_algorithms for the WSN (Winning Subnetworks) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for the [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.
  3"""
  4
  5__all__ = ["WSN"]
  6
  7import logging
  8from typing import Any
  9
 10import torch
 11from torch import Tensor
 12from torch.utils.data import DataLoader
 13
 14from clarena.backbones import WSNMaskBackbone
 15from clarena.cl_algorithms import CLAlgorithm
 16from clarena.heads import HeadsTIL
 17
 18# always get logger for built-in logging in each module
 19pylogger = logging.getLogger(__name__)
 20
 21
 22class WSN(CLAlgorithm):
 23    r"""[WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.
 24
 25    An architecture-based continual learning approach that trains learnable parameter-wise scores and selects the top-scoring c% of network parameters per task.
 26    """
 27
 28    def __init__(
 29        self,
 30        backbone: WSNMaskBackbone,
 31        heads: HeadsTIL,
 32        mask_percentage: float,
 33        parameter_score_init_mode: str = "default",
 34        non_algorithmic_hparams: dict[str, Any] = {},
 35    ) -> None:
 36        r"""Initialize the WSN algorithm with the network.
 37
 38        **Args:**
 39        - **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism.
 40        - **heads** (`HeadsTIL`): output heads. WSN only supports TIL (Task-Incremental Learning).
 41        - **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
 42        - **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of:
 43            1. 'default': the default initialization in the original WSN code.
 44            2. 'N01': standard normal distribution $N(0, 1)$.
 45            3. 'U01': uniform distribution $U(0, 1)$.
 46        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 47
 48        """
 49        super().__init__(
 50            backbone=backbone,
 51            heads=heads,
 52            non_algorithmic_hparams=non_algorithmic_hparams,
 53        )
 54
 55        self.mask_percentage: float = mask_percentage
 56        r"""The percentage of parameters to be used for each task."""
 57        self.parameter_score_init_mode: str = parameter_score_init_mode
 58        r"""The parameter score initialization mode."""
 59
 60        # save additional algorithmic hyperparameters
 61        self.save_hyperparameters(
 62            "mask_percentage",
 63            "parameter_score_init_mode",
 64        )
 65
 66        self.weight_masks: dict[int, dict[str, Tensor]] = {}
 67        r"""The binary weight mask of each previous task percentile-gated from the weight score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as weight."""
 68        self.bias_masks: dict[int, dict[str, Tensor]] = {}
 69        r"""The binary bias mask of each previous task percentile-gated from the bias score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""
 70
 71        self.cumulative_weight_mask_for_previous_tasks: dict[str, Tensor] = {}
 72        r"""The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as weight."""
 73        self.cumulative_bias_mask_for_previous_tasks: dict[str, Tensor] = {}
 74        r"""The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""
 75
 76        # set manual optimization
 77        self.automatic_optimization = False
 78
 79        WSN.sanity_check(self)
 80
 81    def sanity_check(self) -> None:
 82        r"""Sanity check."""
 83
 84        # check the backbone and heads
 85        if not isinstance(self.backbone, WSNMaskBackbone):
 86            raise ValueError("The backbone should be an instance of WSNMaskBackbone.")
 87        if not isinstance(self.heads, HeadsTIL):
 88            raise ValueError("The heads should be an instance of `HeadsTIL`.")
 89
 90        # check the mask percentage
 91        if not (0 < self.mask_percentage <= 1):
 92            raise ValueError(
 93                f"Mask percentage should be in (0, 1], but got {self.mask_percentage}."
 94            )
 95
 96    def on_train_start(self) -> None:
 97        r"""Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task."""
 98
 99        self.backbone.initialize_parameter_score(mode=self.parameter_score_init_mode)
100
101        # initialize the cumulative mask at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
102        if self.task_id == 1:
103            for layer_name in self.backbone.weighted_layer_names:
104                layer = self.backbone.get_layer_by_name(
105                    layer_name
106                )  # get the layer by its name
107
108                self.cumulative_weight_mask_for_previous_tasks[layer_name] = (
109                    torch.zeros_like(layer.weight).to(self.device)
110                )
111                if layer.bias is not None:
112                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = (
113                        torch.zeros_like(layer.bias).to(self.device)
114                    )
115                else:
116                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
117                # the cumulative mask $\mathrm{M}_{t-1}$ is initialized as zeros mask ($t = 1$)
118
119    def clip_grad_by_mask(
120        self,
121    ) -> None:
122        r"""Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)."""
123
124        for layer_name in self.backbone.weighted_layer_names:
125            layer = self.backbone.get_layer_by_name(layer_name)
126
127            layer.weight.grad.data *= (
128                1 - self.cumulative_weight_mask_for_previous_tasks[layer_name]
129            )
130            if layer.bias is not None:
131                layer.bias.grad.data *= (
132                    1 - self.cumulative_bias_mask_for_previous_tasks[layer_name]
133                )
134
135    def forward(
136        self,
137        input: torch.Tensor,
138        stage: str,
139        task_id: int | None = None,
140    ) -> Tensor | tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
141        r"""The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.
142
143        **Args:**
144        - **input** (`Tensor`): the input tensor from data.
145        - **stage** (`str`): the stage of the forward pass, should be one of:
146            1. 'train': training stage.
147            2. 'validation': validation stage.
148            3. 'test': testing stage.
149        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).
150
151        **Returns:**
152        - **logits** (`Tensor`): the output logits tensor.
153        - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, input features) as weight.
154        - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`.
155        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
156        """
157        feature, weight_mask, bias_mask, activations = self.backbone(
158            input,
159            stage=stage,
160            mask_percentage=self.mask_percentage,
161            test_mask=(
162                (self.weight_masks[task_id], self.bias_masks[task_id])
163                if stage == "test"
164                else None
165            ),
166        )
167        logits = self.heads(feature, task_id)
168
169        return (
170            logits
171            if self.if_forward_func_return_logits_only
172            else (logits, weight_mask, bias_mask, activations)
173        )
174
175    def training_step(self, batch: Any) -> dict[str, Tensor]:
176        r"""Training step for current task `self.task_id`.
177
178        **Args:**
179        - **batch** (`Any`): a batch of training data.
180
181        **Returns:**
182        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
183        """
184        x, y = batch
185
186        # zero the gradients before forward pass in manual optimization mode
187        opt = self.optimizers()
188        opt.zero_grad()
189
190        # classification loss
191        logits, weight_mask, bias_mask, activations = self.forward(
192            x, stage="train", task_id=self.task_id
193        )
194        loss_cls = self.criterion(logits, y)
195
196        # total loss
197        loss = loss_cls
198
199        # backward step (manually)
200        self.manual_backward(loss)  # calculate the gradients
201        # WSN hard-clips gradients using the cumulative masks. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
202        self.clip_grad_by_mask()
203
204        # update parameters with the modified gradients
205        opt.step()
206
207        # accuracy of the batch
208        acc = (logits.argmax(dim=1) == y).float().mean()
209
210        return {
211            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
212            "loss_cls": loss_cls,
213            "acc": acc,
214            "activations": activations,
215            "weight_mask": weight_mask,  # return other metrics for Lightning loggers callback to handle at `on_train_batch_end()`
216            "bias_mask": bias_mask,
217        }
218
219    def on_train_end(self) -> None:
220        r"""Store the weight and bias masks and update the cumulative masks after training the task."""
221
222        # get the weight and bias mask for the current task
223        weight_mask_t = {}
224        bias_mask_t = {}
225        for layer_name in self.backbone.weighted_layer_names:
226            layer = self.backbone.get_layer_by_name(layer_name)
227
228            weight_mask_t[layer_name] = self.backbone.gate_fn.apply(
229                self.backbone.weight_score_t[layer_name].weight, self.mask_percentage
230            )
231            if layer.bias is not None:
232                bias_mask_t[layer_name] = self.backbone.gate_fn.apply(
233                    self.backbone.bias_score_t[layer_name].weight.squeeze(
234                        0
235                    ),  # from (1, output_dim) to (output_dim, )
236                    self.mask_percentage,
237                )
238            else:
239                bias_mask_t[layer_name] = None
240
241        # store the weight and bias mask for the current task
242        self.weight_masks[self.task_id] = weight_mask_t
243        self.bias_masks[self.task_id] = bias_mask_t
244
245        # update the cumulative mask
246        for layer_name in self.backbone.weighted_layer_names:
247            layer = self.backbone.get_layer_by_name(layer_name)
248
249            self.cumulative_weight_mask_for_previous_tasks[layer_name] = torch.max(
250                self.cumulative_weight_mask_for_previous_tasks[layer_name],
251                weight_mask_t[layer_name],
252            )
253            if layer.bias is not None:
254                pylogger.debug(
255                    "Bias mask shapes (cumulative, current task): %s, %s",
256                    self.cumulative_bias_mask_for_previous_tasks[layer_name].shape, bias_mask_t[layer_name].shape,
257                )
258                self.cumulative_bias_mask_for_previous_tasks[layer_name] = torch.max(
259                    self.cumulative_bias_mask_for_previous_tasks[layer_name],
260                    bias_mask_t[layer_name],
261                )
262            else:
263                self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
264
265        pylogger.debug("Cumulative bias masks: %s", self.cumulative_bias_mask_for_previous_tasks)
266
267    def validation_step(self, batch: Any) -> dict[str, Tensor]:
268        r"""Validation step for current task `self.task_id`.
269
270        **Args:**
271        - **batch** (`Any`): a batch of validation data.
272
273        **Returns:**
274        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
275        """
276        x, y = batch
277        logits, _, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
278        loss_cls = self.criterion(logits, y)
279        acc = (logits.argmax(dim=1) == y).float().mean()
280
281        return {
282            "loss_cls": loss_cls,
283            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_validation_batch_end()`
284        }
285
286    def test_step(
287        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
288    ) -> dict[str, Tensor]:
289        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
290
291        **Args:**
292        - **batch** (`Any`): a batch of test data.
293        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given; otherwise, the LightningModule will raise a `RuntimeError`.
294
295        **Returns:**
296        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
297        """
298        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
299
300        x, y = batch
301        logits, _, _, _ = self.forward(
302            x,
303            stage="test",
304            task_id=test_task_id,
305        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
306        loss_cls = self.criterion(logits, y)
307        acc = (logits.argmax(dim=1) == y).float().mean()
308
309        return {
310            "loss_cls": loss_cls,
311            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_test_batch_end()`
312        }
class WSN(clarena.cl_algorithms.base.CLAlgorithm):

WSN (Winning Subnetworks) algorithm.

An architecture-based continual learning approach that trains learnable parameter-wise scores and selects the top-scoring c% of network parameters per task.
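
To make the selection step concrete, here is a minimal, self-contained sketch of top-c% gating on a score tensor. It is an illustration only: it does not use the backbone's actual `gate_fn`, and the helper name and tensors are made up.

```python
import torch


def top_c_percent_mask(score: torch.Tensor, c: float) -> torch.Tensor:
    """Return a binary mask selecting the top c (0 < c <= 1) fraction of scores."""
    threshold = torch.quantile(score.flatten(), 1.0 - c)  # (1 - c)-quantile of the scores
    return (score >= threshold).float()


scores = torch.randn(4, 4)                # hypothetical weight scores for one layer
mask = top_c_percent_mask(scores, c=0.5)  # binary mask over the layer's weights
print(mask.mean())                        # roughly half of the entries are selected
```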

WSN( backbone: clarena.backbones.WSNMaskBackbone, heads: clarena.heads.HeadsTIL, mask_percentage: float, parameter_score_init_mode: str = 'default', non_algorithmic_hparams: dict[str, typing.Any] = {})
29    def __init__(
30        self,
31        backbone: WSNMaskBackbone,
32        heads: HeadsTIL,
33        mask_percentage: float,
34        parameter_score_init_mode: str = "default",
35        non_algorithmic_hparams: dict[str, Any] = {},
36    ) -> None:
37        r"""Initialize the WSN algorithm with the network.
38
39        **Args:**
40        - **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism.
41        - **heads** (`HeadsTIL`): output heads. WSN only supports TIL (Task-Incremental Learning).
42        - **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
43        - **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of:
44            1. 'default': the default initialization in the original WSN code.
45            2. 'N01': standard normal distribution $N(0, 1)$.
46            3. 'U01': uniform distribution $U(0, 1)$.
47        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
48
49        """
50        super().__init__(
51            backbone=backbone,
52            heads=heads,
53            non_algorithmic_hparams=non_algorithmic_hparams,
54        )
55
56        self.mask_percentage: float = mask_percentage
57        r"""The percentage of parameters to be used for each task."""
58        self.parameter_score_init_mode: str = parameter_score_init_mode
59        r"""The parameter score initialization mode."""
60
61        # save additional algorithmic hyperparameters
62        self.save_hyperparameters(
63            "mask_percentage",
64            "parameter_score_init_mode",
65        )
66
67        self.weight_masks: dict[int, dict[str, Tensor]] = {}
68        r"""The binary weight mask of each previous task percentile-gated from the weight score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as weight."""
69        self.bias_masks: dict[int, dict[str, Tensor]] = {}
70        r"""The binary bias mask of each previous task percentile-gated from the bias score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""
71
72        self.cumulative_weight_mask_for_previous_tasks: dict[str, Tensor] = {}
73        r"""The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as weight."""
74        self.cumulative_bias_mask_for_previous_tasks: dict[str, Tensor] = {}
75        r"""The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`."""
76
77        # set manual optimization
78        self.automatic_optimization = False
79
80        WSN.sanity_check(self)

Initialize the WSN algorithm with the network.

Args:

  • backbone (WSNMaskBackbone): must be a backbone network with the WSN mask mechanism.
  • heads (HeadsTIL): output heads. WSN only supports TIL (Task-Incremental Learning).
  • mask_percentage (float): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the WSN paper.
  • parameter_score_init_mode (str): the initialization mode for parameter scores, must be one of:
    1. 'default': the default initialization in the original WSN code.
    2. 'N01': standard normal distribution $N(0, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
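
For reference, a hedged construction sketch (not taken from the repository's configs). The backbone and heads are placeholders that must be built elsewhere with their own, project-specific arguments.

```python
from clarena.backbones import WSNMaskBackbone
from clarena.cl_algorithms import WSN
from clarena.heads import HeadsTIL

backbone: WSNMaskBackbone = ...  # placeholder: a WSN-capable backbone constructed elsewhere
heads: HeadsTIL = ...            # placeholder: TIL output heads constructed elsewhere

algorithm = WSN(
    backbone=backbone,
    heads=heads,
    mask_percentage=0.5,                  # select the top 50% of scored parameters per task
    parameter_score_init_mode="default",  # or "N01" / "U01"
)
```
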
mask_percentage: float

The percentage of parameters to be used for each task.

parameter_score_init_mode: str

The parameter score initialization mode.

weight_masks: dict[int, dict[str, torch.Tensor]]

The binary weight mask of each previous task percentile-gated from the weight score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as weight.

bias_masks: dict[int, dict[str, torch.Tensor]]

The binary bias mask of each previous task percentile-gated from the bias score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is None.

cumulative_weight_mask_for_previous_tasks: dict[str, torch.Tensor]

The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as weight.

cumulative_bias_mask_for_previous_tasks: dict[str, torch.Tensor]

The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is None.
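
Once a task has been trained, these attributes can be inspected, e.g. to check how much of a layer a task occupies. A small fragment continuing the construction sketch above; the layer name "fc1" is a placeholder for whatever the backbone's `weighted_layer_names` actually contains.

```python
# hypothetical inspection after training task 1
mask_t1 = algorithm.weight_masks[1]           # dict: layer name -> binary weight mask
fraction_used = mask_t1["fc1"].mean().item()  # fraction of selected weights, roughly mask_percentage
print(f"Task 1 uses {fraction_used:.1%} of layer fc1.")
```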

automatic_optimization: bool
290    @property
291    def automatic_optimization(self) -> bool:
292        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
293        return self._automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().
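
In plain PyTorch terms, manual optimization means the module runs these steps itself; the toy example below is only an analogue of the Lightning hooks, not code from this repository. WSN inserts its gradient masking between the backward pass and the optimizer step.

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

optimizer.zero_grad()                                   # corresponds to opt.zero_grad()
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()                                         # corresponds to self.manual_backward(loss)
# ... gradients could be edited here, as WSN does in clip_grad_by_mask() ...
optimizer.step()                                        # corresponds to opt.step()
```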

def sanity_check(self) -> None:
82    def sanity_check(self) -> None:
83        r"""Sanity check."""
84
85        # check the backbone and heads
86        if not isinstance(self.backbone, WSNMaskBackbone):
87            raise ValueError("The backbone should be an instance of WSNMaskBackbone.")
88        if not isinstance(self.heads, HeadsTIL):
89            raise ValueError("The heads should be an instance of `HeadsTIL`.")
90
91        # check the mask percentage
92        if not (0 < self.mask_percentage <= 1):
93            raise ValueError(
94                f"Mask percentage should be in (0, 1], but got {self.mask_percentage}."
95            )

Sanity check.

def on_train_start(self) -> None:
 97    def on_train_start(self) -> None:
 98        r"""Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task."""
 99
100        self.backbone.initialize_parameter_score(mode=self.parameter_score_init_mode)
101
102        # initialize the cumulative mask at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
103        if self.task_id == 1:
104            for layer_name in self.backbone.weighted_layer_names:
105                layer = self.backbone.get_layer_by_name(
106                    layer_name
107                )  # get the layer by its name
108
109                self.cumulative_weight_mask_for_previous_tasks[layer_name] = (
110                    torch.zeros_like(layer.weight).to(self.device)
111                )
112                if layer.bias is not None:
113                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = (
114                        torch.zeros_like(layer.bias).to(self.device)
115                    )
116                else:
117                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
118                # the cumulative mask $\mathrm{M}_{t-1}$ is initialized as zeros mask ($t = 1$)

Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task.

def clip_grad_by_mask(self) -> None:
120    def clip_grad_by_mask(
121        self,
122    ) -> None:
123        r"""Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)."""
124
125        for layer_name in self.backbone.weighted_layer_names:
126            layer = self.backbone.get_layer_by_name(layer_name)
127
128            layer.weight.grad.data *= (
129                1 - self.cumulative_weight_mask_for_previous_tasks[layer_name]
130            )
131            if layer.bias is not None:
132                layer.bias.grad.data *= (
133                    1 - self.cumulative_bias_mask_for_previous_tasks[layer_name]
134                )

Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the WSN paper.
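
A small numeric illustration with made-up tensors: entries already claimed by previous tasks (mask value 1) receive zero gradient and stay fixed, while free entries (mask value 0) are updated normally.

```python
import torch

grad = torch.tensor([0.3, -1.2, 0.8, 0.5])             # gradient of one layer's weights
cumulative_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])   # 1 = used by some previous task

grad = grad * (1 - cumulative_mask)
print(grad)  # tensor([ 0.0000, -1.2000,  0.0000,  0.5000])
```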

def forward( self, input: torch.Tensor, stage: str, task_id: int | None = None) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, torch.Tensor]]:
136    def forward(
137        self,
138        input: torch.Tensor,
139        stage: str,
140        task_id: int | None = None,
141    ) -> Tensor | tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
142        r"""The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.
143
144        **Args:**
145        - **input** (`Tensor`): the input tensor from data.
146        - **stage** (`str`): the stage of the forward pass, should be one of:
147            1. 'train': training stage.
148            2. 'validation': validation stage.
149            3. 'test': testing stage.
150        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).
151
152        **Returns:**
153        - **logits** (`Tensor`): the output logits tensor.
154        - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, input features) as weight.
155        - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`.
156        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
157        """
158        feature, weight_mask, bias_mask, activations = self.backbone(
159            input,
160            stage=stage,
161            mask_percentage=self.mask_percentage,
162            test_mask=(
163                (self.weight_masks[task_id], self.bias_masks[task_id])
164                if stage == "test"
165                else None
166            ),
167        )
168        logits = self.heads(feature, task_id)
169
170        return (
171            logits
172            if self.if_forward_func_return_logits_only
173            else (logits, weight_mask, bias_mask, activations)
174        )

The forward pass for data from task task_id. Note that it has nothing to do with the forward() method in nn.Module.

Args:

  • input (Tensor): the input tensor from data.
  • stage (str): the stage of the forward pass, should be one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • task_id (int | None): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task self.task_id. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).

Returns:

  • logits (Tensor): the output logits tensor.
  • weight_mask (dict[str, Tensor]): the weight mask for the current task. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has the same size (output features, input features) as weight.
  • bias_mask (dict[str, Tensor]): the bias mask for the current task. Key (str) is layer name, value (Tensor) is the mask tensor. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is None.
  • activations (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes.
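
As a usage fragment (continuing the construction sketch above, and assuming task 1 has already been trained, `if_forward_func_return_logits_only` is `False`, and `x` is a batch of task-1 inputs of the right shape):

```python
logits, weight_mask, bias_mask, activations = algorithm.forward(x, stage="test", task_id=1)
predictions = logits.argmax(dim=1)  # per-sample class predictions within task 1's head
```
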
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
176    def training_step(self, batch: Any) -> dict[str, Tensor]:
177        r"""Training step for current task `self.task_id`.
178
179        **Args:**
180        - **batch** (`Any`): a batch of training data.
181
182        **Returns:**
183        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
184        """
185        x, y = batch
186
187        # zero the gradients before forward pass in manual optimization mode
188        opt = self.optimizers()
189        opt.zero_grad()
190
191        # classification loss
192        logits, weight_mask, bias_mask, activations = self.forward(
193            x, stage="train", task_id=self.task_id
194        )
195        loss_cls = self.criterion(logits, y)
196
197        # total loss
198        loss = loss_cls
199
200        # backward step (manually)
201        self.manual_backward(loss)  # calculate the gradients
202        # WSN hard-clips gradients using the cumulative masks. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
203        self.clip_grad_by_mask()
204
205        # update parameters with the modified gradients
206        opt.step()
207
208        # accuracy of the batch
209        acc = (logits.argmax(dim=1) == y).float().mean()
210
211        return {
212            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
213            "loss_cls": loss_cls,
214            "acc": acc,
215            "activations": activations,
216            "weight_mask": weight_mask,  # return other metrics for Lightning loggers callback to handle at `on_train_batch_end()`
217            "bias_mask": bias_mask,
218        }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing loss and other metrics from this training step. Keys (str) are the metrics names, and values (Tensor) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
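
The extra entries are intended for logger callbacks. A hedged sketch of what such a callback could do with them; the callback class is invented here, the hook signature follows Lightning 2.x conventions, and the import may be `pytorch_lightning` on older versions.

```python
from lightning.pytorch.callbacks import Callback


class MaskSparsityLogger(Callback):
    """Log the fraction of parameters selected per layer after each training batch."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        for layer_name, mask in outputs["weight_mask"].items():
            pl_module.log(f"mask_fraction/{layer_name}", mask.float().mean())
```
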
def on_train_end(self) -> None:
220    def on_train_end(self) -> None:
221        r"""Store the weight and bias masks and update the cumulative masks after training the task."""
222
223        # get the weight and bias mask for the current task
224        weight_mask_t = {}
225        bias_mask_t = {}
226        for layer_name in self.backbone.weighted_layer_names:
227            layer = self.backbone.get_layer_by_name(layer_name)
228
229            weight_mask_t[layer_name] = self.backbone.gate_fn.apply(
230                self.backbone.weight_score_t[layer_name].weight, self.mask_percentage
231            )
232            if layer.bias is not None:
233                bias_mask_t[layer_name] = self.backbone.gate_fn.apply(
234                    self.backbone.bias_score_t[layer_name].weight.squeeze(
235                        0
236                    ),  # from (1, output_dim) to (output_dim, )
237                    self.mask_percentage,
238                )
239            else:
240                bias_mask_t[layer_name] = None
241
242        # store the weight and bias mask for the current task
243        self.weight_masks[self.task_id] = weight_mask_t
244        self.bias_masks[self.task_id] = bias_mask_t
245
246        # update the cumulative mask
247        for layer_name in self.backbone.weighted_layer_names:
248            layer = self.backbone.get_layer_by_name(layer_name)
249
250            self.cumulative_weight_mask_for_previous_tasks[layer_name] = torch.max(
251                self.cumulative_weight_mask_for_previous_tasks[layer_name],
252                weight_mask_t[layer_name],
253            )
254            if layer.bias is not None:
255                pylogger.debug(
256                    "Bias mask shapes (cumulative, current task): %s, %s",
257                    self.cumulative_bias_mask_for_previous_tasks[layer_name].shape, bias_mask_t[layer_name].shape,
258                )
259                self.cumulative_bias_mask_for_previous_tasks[layer_name] = torch.max(
260                    self.cumulative_bias_mask_for_previous_tasks[layer_name],
261                    bias_mask_t[layer_name],
262                )
263            else:
264                self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
265
266        pylogger.debug("Cumulative bias masks: %s", self.cumulative_bias_mask_for_previous_tasks)

Store the weight and bias masks and update the cumulative masks after training the task.
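
The cumulative mask update is an element-wise maximum, i.e. a running union of the binary masks of all tasks seen so far. With made-up masks:

```python
import torch

cumulative = torch.tensor([1.0, 0.0, 1.0, 0.0])  # M_{t-1}: union of masks of tasks 1..t-1
current = torch.tensor([0.0, 0.0, 1.0, 1.0])     # m_t: mask selected for task t

cumulative = torch.max(cumulative, current)
print(cumulative)  # tensor([1., 0., 1., 1.])  -> the new cumulative mask M_t
```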

def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
268    def validation_step(self, batch: Any) -> dict[str, Tensor]:
269        r"""Validation step for current task `self.task_id`.
270
271        **Args:**
272        - **batch** (`Any`): a batch of validation data.
273
274        **Returns:**
275        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
276        """
277        x, y = batch
278        logits, _, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
279        loss_cls = self.criterion(logits, y)
280        acc = (logits.argmax(dim=1) == y).float().mean()
281
282        return {
283            "loss_cls": loss_cls,
284            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_validation_batch_end()`
285        }

Validation step for current task self.task_id.

Args:

  • batch (Any): a batch of validation data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing loss and other metrics from this validation step. Keys (str) are the metrics names, and values (Tensor) are the metrics.
def test_step( self, batch: typing.Any, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
287    def test_step(
288        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
289    ) -> dict[str, Tensor]:
290        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
291
292        **Args:**
293        - **batch** (`Any`): a batch of test data.
294        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given; otherwise, the LightningModule will raise a `RuntimeError`.
295
296        **Returns:**
297        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
298        """
299        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
300
301        x, y = batch
302        logits, _, _, _ = self.forward(
303            x,
304            stage="test",
305            task_id=test_task_id,
306        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
307        loss_cls = self.criterion(logits, y)
308        acc = (logits.argmax(dim=1) == y).float().mean()
309
310        return {
311            "loss_cls": loss_cls,
312            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_test_batch_end()`
313        }

Test step for current task self.task_id, which tests for all seen tasks indexed by dataloader_idx.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given; otherwise, the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing loss and other metrics from this test step. Keys (str) are the metrics names, and values (Tensor) are the metrics.