clarena.cl_algorithms.wsn
The submodule in cl_algorithms for the WSN (Winning Subnetworks) algorithm.
r"""
The submodule in `cl_algorithms` for the [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.
"""

__all__ = ["WSN"]

import logging
from typing import Any

import torch
from torch import Tensor

from clarena.backbones import WSNMaskBackbone
from clarena.cl_algorithms import CLAlgorithm
from clarena.heads import HeadDIL, HeadsTIL

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class WSN(CLAlgorithm):
    r"""[WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.

    An architecture-based continual learning approach that trains learnable parameter-wise scores and selects the top-scored $c\%$ of network parameters per task.
    """

    def __init__(
        self,
        backbone: WSNMaskBackbone,
        heads: HeadsTIL | HeadDIL,
        mask_percentage: float,
        parameter_score_init_mode: str = "default",
        non_algorithmic_hparams: dict[str, Any] = {},
        **kwargs,
    ) -> None:
        r"""Initialize the WSN algorithm with the network.

        **Args:**
        - **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. WSN supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
        - **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
        - **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of:
            1. 'default': the default initialization in the original WSN code.
            2. 'N01': standard normal distribution $N(0, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (such as optimizer and learning rate scheduler configurations) passed to this `LightningModule` object from the config. They are saved for Lightning APIs through the `save_hyperparameters()` method, which helps experiment configuration and reproducibility.
        - **kwargs**: reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        self.mask_percentage: float = mask_percentage
        r"""The percentage of parameters to be used for each task."""
        self.parameter_score_init_mode: str = parameter_score_init_mode
        r"""The parameter score initialization mode."""

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "mask_percentage",
            "parameter_score_init_mode",
        )

        self.weight_masks: dict[int, dict[str, Tensor]] = {}
        r"""The binary weight mask of each previous task, percentile-gated from the weight score. Keys are task IDs and values are the corresponding masks. Each mask is a dict where keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has the same size (output features, input features) as the weight."""
        self.bias_masks: dict[int, dict[str, Tensor | None]] = {}
        r"""The binary bias mask of each previous task, percentile-gated from the bias score. Keys are task IDs and values are the corresponding masks. Each mask is a dict where keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`."""

        self.cumulative_weight_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has the same size (output features, input features) as the weight."""
        self.cumulative_bias_mask_for_previous_tasks: dict[str, Tensor | None] = {}
        r"""The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`."""

        # set manual optimization
        self.automatic_optimization = False

        WSN.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check the backbone and heads
        if not isinstance(self.backbone, WSNMaskBackbone):
            raise ValueError("The backbone should be an instance of `WSNMaskBackbone`.")
        if not isinstance(self.heads, (HeadsTIL, HeadDIL)):
            raise ValueError(
                "The heads should be an instance of `HeadsTIL` or `HeadDIL`."
            )

        # check the mask percentage
        if not (0 < self.mask_percentage <= 1):
            raise ValueError(
                f"Mask percentage should be in (0, 1], but got {self.mask_percentage}."
            )

    def on_train_start(self) -> None:
        r"""Initialize the parameter scores before training the next task, and initialize the cumulative masks at the beginning of the first task."""

        self.backbone.initialize_parameter_score(mode=self.parameter_score_init_mode)

        # initialize the cumulative masks at the beginning of the first task. This
        # cannot be done in `__init__()` because `self.device` is not available there.
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name

                self.cumulative_weight_mask_for_previous_tasks[layer_name] = (
                    torch.zeros_like(layer.weight).to(self.device)
                )
                if layer.bias is not None:
                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = (
                        torch.zeros_like(layer.bias).to(self.device)
                    )
                else:
                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
            # the cumulative mask $\mathbf{M}_{t-1}$ is initialized as a zero mask ($t = 1$)

    def clip_grad_by_mask(self) -> None:
        r"""Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative previous mask) to keep previously masked parameters fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)."""

        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            layer.weight.grad.data *= (
                1 - self.cumulative_weight_mask_for_previous_tasks[layer_name]
            )
            if layer.bias is not None:
                layer.bias.grad.data *= (
                    1 - self.cumulative_bias_mask_for_previous_tasks[layer_name]
                )

    def forward(
        self,
        input: Tensor,
        stage: str,
        task_id: int | None = None,
    ) -> (
        Tensor | tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]
    ):
        r"""The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **stage** (`str`): the stage of the forward pass, should be one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be any seen task (TIL uses the provided task IDs for testing).

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Keys (`str`) are layer names, values (`Tensor`) are the mask tensors. Each mask tensor has the same size (output features, input features) as the weight.
        - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Keys (`str`) are layer names, values (`Tensor`) are the mask tensors. Each mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names, values (`Tensor`) are the hidden feature tensors. This is used by continual learning algorithms that need the hidden features for various purposes.
        """
        feature, weight_mask, bias_mask, activations = self.backbone(
            input,
            stage=stage,
            mask_percentage=self.mask_percentage,
            test_mask=(
                (self.weight_masks[task_id], self.bias_masks[task_id])
                if stage == "test"
                else None
            ),
        )
        logits = self.heads(feature, task_id)

        return (
            logits
            if self.if_forward_func_return_logits_only
            else (logits, weight_mask, bias_mask, activations)
        )

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are the metric names and values (`Tensor`) are the metrics. According to the PyTorch Lightning docs, it must include the key 'loss', which is the total loss in the case of automatic optimization. For WSN, it also includes 'weight_mask' and 'bias_mask' for logging.
        """
        x, y = batch

        # zero the gradients before the forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        logits, weight_mask, bias_mask, activations = self.forward(
            x, stage="train", task_id=self.task_id
        )
        loss_cls = self.criterion(logits, y)

        # total loss
        loss = loss_cls

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # WSN hard-clips gradients using the cumulative masks. See Eq. (4) in the
        # [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
        self.clip_grad_by_mask()

        # update parameters with the modified gradients
        opt.step()

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,
            "activations": activations,
            "weight_mask": weight_mask,  # return other metrics for the Lightning loggers callback to handle at `on_train_batch_end()`
            "bias_mask": bias_mask,
        }

    def on_train_end(self) -> None:
        r"""Store the weight and bias masks and update the cumulative masks after training the task."""

        # get the weight and bias masks for the current task
        weight_mask_t = {}
        bias_mask_t = {}
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            weight_mask_t[layer_name] = self.backbone.gate_fn.apply(
                self.backbone.weight_score_t[layer_name].weight, self.mask_percentage
            )
            if layer.bias is not None:
                bias_mask_t[layer_name] = self.backbone.gate_fn.apply(
                    self.backbone.bias_score_t[layer_name].weight.squeeze(
                        0
                    ),  # from (1, output_dim) to (output_dim, )
                    self.mask_percentage,
                )
            else:
                bias_mask_t[layer_name] = None

        # store the weight and bias masks for the current task
        self.weight_masks[self.task_id] = weight_mask_t
        self.bias_masks[self.task_id] = bias_mask_t

        # update the cumulative masks
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            self.cumulative_weight_mask_for_previous_tasks[layer_name] = torch.max(
                self.cumulative_weight_mask_for_previous_tasks[layer_name],
                weight_mask_t[layer_name],
            )
            if layer.bias is not None:
                self.cumulative_bias_mask_for_previous_tasks[layer_name] = torch.max(
                    self.cumulative_bias_mask_for_previous_tasks[layer_name],
                    bias_mask_t[layer_name],
                )
            else:
                self.cumulative_bias_mask_for_previous_tasks[layer_name] = None

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Keys (`str`) are the metric names and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the Lightning loggers callback to handle at `on_validation_batch_end()`
            "preds": preds,
        }

    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **batch_idx** (`int`): the index of the current batch, required by Lightning's `test_step()` signature.
        - **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given, otherwise the `LightningModule` will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Keys (`str`) are the metric names and values (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y = batch
        logits, _, _, _ = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the Lightning loggers callback to handle at `on_test_batch_end()`
            "preds": preds,
        }
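The percentile gating used throughout (selecting the top-scored $c\%$ of parameters) can be sketched in isolation. The `percentile_gate` helper below is a hypothetical stand-in for the backbone's `gate_fn` (whose straight-through autograd behavior is not shown here), intended only to illustrate how a binary mask is derived from a score tensor:

```python
import torch


def percentile_gate(score: torch.Tensor, c: float) -> torch.Tensor:
    # hypothetical stand-in for the backbone's `gate_fn`: keep the
    # top c-fraction of entries by score as a binary mask
    k = max(1, int(c * score.numel()))
    threshold = torch.topk(score.flatten(), k).values.min()
    return (score >= threshold).float()


score = torch.tensor([[0.9, 0.1], [0.5, 0.3]])
mask = percentile_gate(score, 0.5)  # keeps the 2 highest-scored entries
```

With `c = 0.5`, the entries scored 0.9 and 0.5 are kept, so `mask` is `[[1, 0], [1, 0]]`.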
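The two mask operations at the heart of WSN's forgetting prevention, the gradient gating in `clip_grad_by_mask()` and the cumulative-mask update in `on_train_end()`, reduce to elementwise arithmetic. A toy sketch on made-up 4-parameter masks (the values are illustrative only):

```python
import torch

# binary masks selected for two tasks over the same 4 parameters
mask_task1 = torch.tensor([1.0, 0.0, 1.0, 0.0])
mask_task2 = torch.tensor([0.0, 1.0, 1.0, 0.0])

# cumulative mask M_t: elementwise max (union) of the per-task masks
cumulative_mask = torch.max(mask_task1, mask_task2)

# Eq. (4)-style gating: zero the gradients of previously selected
# parameters so they stay fixed while training the next task
grad = torch.tensor([0.5, 0.5, 0.5, 0.5])
grad = grad * (1 - cumulative_mask)
```

Only the last parameter, unused by both tasks, keeps a nonzero gradient.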
WSN (Winning Subnetworks) algorithm.
An architecture-based continual learning approach that trains learnable parameter-wise scores and selects the most scored c% of network parameters per task.
29 def __init__( 30 self, 31 backbone: WSNMaskBackbone, 32 heads: HeadsTIL | HeadDIL, 33 mask_percentage: float, 34 parameter_score_init_mode: str = "default", 35 non_algorithmic_hparams: dict[str, Any] = {}, 36 **kwargs, 37 ) -> None: 38 r"""Initialize the WSN algorithm with the network. 39 40 **Args:** 41 - **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism. 42 - **heads** (`HeadsTIL` | `HeadDIL`): output heads. WSN supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning). 43 - **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf). 44 - **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of: 45 1. 'default': the default initialization in the original WSN code. 46 2. 'N01': standard normal distribution $N(0, 1)$. 47 3. 'U01': uniform distribution $U(0, 1)$. 48 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 49 - **kwargs**: Reserved for multiple inheritance. 
50 51 """ 52 super().__init__( 53 backbone=backbone, 54 heads=heads, 55 non_algorithmic_hparams=non_algorithmic_hparams, 56 **kwargs, 57 ) 58 59 self.mask_percentage: float = mask_percentage 60 r"""The percentage of parameters to be used for each task.""" 61 self.parameter_score_init_mode: str = parameter_score_init_mode 62 r"""The parameter score initialization mode.""" 63 64 # save additional algorithmic hyperparameters 65 self.save_hyperparameters( 66 "mask_percentage", 67 "parameter_score_init_mode", 68 ) 69 70 self.weight_masks: dict[int, dict[str, Tensor]] = {} 71 r"""The binary weight mask of each previous task percentile-gated from the weight score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as weight.""" 72 self.bias_masks: dict[int, dict[str, Tensor]] = {} 73 r"""The binary bias mask of each previous task percentile-gated from the bias score. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`.""" 74 75 self.cumulative_weight_mask_for_previous_tasks: dict[str, Tensor] = {} 76 r"""The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as weight.""" 77 self.cumulative_bias_mask_for_previous_tasks: dict[str, dict[str, Tensor]] = {} 78 r"""The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. 
It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`.""" 79 80 # set manual optimization 81 self.automatic_optimization = False 82 83 WSN.sanity_check(self)
    @property
    def automatic_optimization(self) -> bool:
        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
        return self._automatic_optimization
    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check the backbone and heads
        if not isinstance(self.backbone, WSNMaskBackbone):
            raise ValueError("The backbone should be an instance of `WSNMaskBackbone`.")
        if not isinstance(self.heads, (HeadsTIL, HeadDIL)):
            raise ValueError(
                "The heads should be an instance of `HeadsTIL` or `HeadDIL`."
            )

        # check the mask percentage
        if not (0 < self.mask_percentage <= 1):
            raise ValueError(
                f"Mask percentage should be in (0, 1], but got {self.mask_percentage}."
            )
    def on_train_start(self) -> None:
        r"""Initialize the parameter scores before training the next task, and initialize the cumulative masks at the beginning of the first task."""

        self.backbone.initialize_parameter_score(mode=self.parameter_score_init_mode)

        # initialize the cumulative masks at the beginning of the first task. This cannot be done in `__init__()` because `self.device` is not available at that time
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name

                # the cumulative mask $\mathbf{M}_{t-1}$ is initialized as a zero mask ($t = 1$)
                self.cumulative_weight_mask_for_previous_tasks[layer_name] = (
                    torch.zeros_like(layer.weight).to(self.device)
                )
                if layer.bias is not None:
                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = (
                        torch.zeros_like(layer.bias).to(self.device)
                    )
                else:
                    self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
    def clip_grad_by_mask(self) -> None:
        r"""Clip the gradients by the cumulative masks. The gradients are multiplied by $(1 - \mathbf{M}_{t-1})$ so that parameters selected by previous tasks stay fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)."""

        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            layer.weight.grad.data *= (
                1 - self.cumulative_weight_mask_for_previous_tasks[layer_name]
            )
            if layer.bias is not None:
                layer.bias.grad.data *= (
                    1 - self.cumulative_bias_mask_for_previous_tasks[layer_name]
                )
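The effect of this gradient clipping can be seen in isolation: multiplying a gradient by $(1 - \mathbf{M}_{t-1})$ zeroes the updates of every parameter already claimed by a previous task, while leaving the rest untouched. A standalone sketch (not the class method itself):

```python
import torch

# a 2x2 weight whose first row was claimed by a previous task
weight = torch.nn.Parameter(torch.ones(2, 2))
cumulative_mask = torch.tensor([[1.0, 1.0], [0.0, 0.0]])

loss = weight.sum()
loss.backward()  # every gradient entry is 1

weight.grad.data *= 1 - cumulative_mask  # freeze previously masked entries
# weight.grad is now [[0., 0.], [1., 1.]]: only the free row will be updated
```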
    def forward(
        self,
        input: torch.Tensor,
        stage: str,
        task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **stage** (`str`): the stage of the forward pass, should be one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, input features) as weight.
        - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, ) as bias. If the layer doesn't have bias, it is `None`.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
        """
        feature, weight_mask, bias_mask, activations = self.backbone(
            input,
            stage=stage,
            mask_percentage=self.mask_percentage,
            test_mask=(
                (self.weight_masks[task_id], self.bias_masks[task_id])
                if stage == "test"
                else None
            ),
        )
        logits = self.heads(feature, task_id)

        return (
            logits
            if self.if_forward_func_return_logits_only
            else (logits, weight_mask, bias_mask, activations)
        )
    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
        """
        x, y = batch

        # zero the gradients before forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        logits, weight_mask, bias_mask, activations = self.forward(
            x, stage="train", task_id=self.task_id
        )
        loss_cls = self.criterion(logits, y)

        # total loss
        loss = loss_cls

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # WSN hard-clips gradients using the cumulative masks. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)
        self.clip_grad_by_mask()

        # update parameters with the modified gradients
        opt.step()

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # returning loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,
            "activations": activations,
            "weight_mask": weight_mask,  # return other metrics for Lightning loggers callback to handle at `on_train_batch_end()`
            "bias_mask": bias_mask,
        }
    def on_train_end(self) -> None:
        r"""Store the weight and bias masks and update the cumulative masks after training the task."""

        # get the weight and bias masks for the current task
        weight_mask_t = {}
        bias_mask_t = {}
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            weight_mask_t[layer_name] = self.backbone.gate_fn.apply(
                self.backbone.weight_score_t[layer_name].weight, self.mask_percentage
            )
            if layer.bias is not None:
                bias_mask_t[layer_name] = self.backbone.gate_fn.apply(
                    self.backbone.bias_score_t[layer_name].weight.squeeze(
                        0
                    ),  # from (1, output_dim) to (output_dim, )
                    self.mask_percentage,
                )
            else:
                bias_mask_t[layer_name] = None

        # store the weight and bias masks for the current task
        self.weight_masks[self.task_id] = weight_mask_t
        self.bias_masks[self.task_id] = bias_mask_t

        # update the cumulative masks
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            self.cumulative_weight_mask_for_previous_tasks[layer_name] = torch.max(
                self.cumulative_weight_mask_for_previous_tasks[layer_name],
                weight_mask_t[layer_name],
            )
            if layer.bias is not None:
                self.cumulative_bias_mask_for_previous_tasks[layer_name] = torch.max(
                    self.cumulative_bias_mask_for_previous_tasks[layer_name],
                    bias_mask_t[layer_name],
                )
            else:
                self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
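The cumulative mask update above is a per-entry union of binary masks: since all masks are 0/1-valued, `torch.max` marks an entry as used whenever it is used by any task so far. In isolation:

```python
import torch

cumulative = torch.tensor([1.0, 0.0, 1.0, 0.0])  # M_{t-1}: entries used by tasks 1..t-1
task_mask = torch.tensor([0.0, 0.0, 1.0, 1.0])   # m_t: entries selected for the new task

cumulative = torch.max(cumulative, task_mask)  # union: entry is 1 if either mask is 1
# cumulative is now [1., 0., 1., 1.]
```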
    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_validation_batch_end()`
            "preds": preds,
        }
    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **batch_idx** (`int`): the index of the current batch.
        - **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given, otherwise the `LightningModule` will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y = batch
        logits, _, _, _ = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_test_batch_end()`
            "preds": preds,
        }