clarena.cl_algorithms.wsn
The submodule in `cl_algorithms` for the [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.
```python
r"""
The submodule in `cl_algorithms` for the [WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.
"""

__all__ = ["WSN"]

import logging
from typing import Any

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from clarena.backbones import WSNMaskBackbone
from clarena.cl_algorithms import CLAlgorithm
from clarena.heads import HeadsTIL

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)
```
`class WSN(CLAlgorithm)`
[WSN (Winning Subnetworks)](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf) algorithm.

An architecture-based continual learning approach that trains learnable parameter-wise scores and selects the top-scoring $c\%$ of network parameters for each task.
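A minimal, self-contained sketch of the core idea (illustrative only, not the `clarena` implementation): every parameter has a learnable score, and the top-$c\%$ scored parameters form the binary mask that defines the task's winning subnetwork.

```python
import torch


def top_c_percent_mask(score: torch.Tensor, c: float) -> torch.Tensor:
    """Return a binary mask selecting the top fraction `c` (0 < c <= 1) of scores."""
    threshold = torch.quantile(score.flatten(), 1.0 - c)
    return (score >= threshold).float()


score = torch.randn(4, 8)  # parameter-wise scores for a hypothetical 4x8 weight matrix
mask = top_c_percent_mask(score, c=0.3)  # keep roughly the top-scoring 30% of parameters
print(int(mask.sum()), "of", mask.numel(), "parameters selected")
```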
```python
def __init__(
    self,
    backbone: WSNMaskBackbone,
    heads: HeadsTIL,
    mask_percentage: float,
    parameter_score_init_mode: str = "default",
    non_algorithmic_hparams: dict[str, Any] = {},
) -> None:
    r"""Initialize the WSN algorithm with the network.

    **Args:**
    - **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism.
    - **heads** (`HeadsTIL`): output heads. WSN only supports TIL (Task-Incremental Learning).
    - **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
    - **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of:
        1. 'default': the default initialization in the original WSN code.
        2. 'N01': standard normal distribution $N(0, 1)$.
        3. 'U01': uniform distribution $U(0, 1)$.
    - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself and are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method. This is useful for experiment configuration and reproducibility.
    """
    super().__init__(
        backbone=backbone,
        heads=heads,
        non_algorithmic_hparams=non_algorithmic_hparams,
    )

    self.mask_percentage: float = mask_percentage
    r"""The percentage of parameters to be used for each task."""
    self.parameter_score_init_mode: str = parameter_score_init_mode
    r"""The parameter score initialization mode."""

    # save additional algorithmic hyperparameters
    self.save_hyperparameters(
        "mask_percentage",
        "parameter_score_init_mode",
    )

    self.weight_masks: dict[int, dict[str, Tensor]] = {}
    r"""The binary weight mask of each previous task, percentile-gated from the weight score. Keys are task IDs and values are the corresponding masks. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as the weight."""
    self.bias_masks: dict[int, dict[str, Tensor]] = {}
    r"""The binary bias mask of each previous task, percentile-gated from the bias score. Keys are task IDs and values are the corresponding masks. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`."""

    self.cumulative_weight_mask_for_previous_tasks: dict[str, Tensor] = {}
    r"""The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as the weight."""
    self.cumulative_bias_mask_for_previous_tasks: dict[str, Tensor] = {}
    r"""The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`."""

    # set manual optimization
    self.automatic_optimization = False

    WSN.sanity_check(self)
```
Initialize the WSN algorithm with the network.
**Args:**
- **backbone** (`WSNMaskBackbone`): must be a backbone network with the WSN mask mechanism.
- **heads** (`HeadsTIL`): output heads. WSN only supports TIL (Task-Incremental Learning).
- **mask_percentage** (`float`): the percentage $c\%$ of parameters to be used for each task. See Sec. 3 and Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
- **parameter_score_init_mode** (`str`): the initialization mode for parameter scores, must be one of:
    1. 'default': the default initialization in the original WSN code.
    2. 'N01': standard normal distribution $N(0, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
- **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself and are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method. This is useful for experiment configuration and reproducibility.
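A hypothetical instantiation sketch (the `backbone` and `heads` objects are assumed to be built elsewhere; their constructor signatures belong to the `clarena.backbones` and `clarena.heads` submodules and are not shown here):

```python
from clarena.cl_algorithms import WSN

# `backbone` and `heads` are placeholders for objects constructed elsewhere:
# a WSNMaskBackbone and a HeadsTIL instance, respectively.
algorithm = WSN(
    backbone=backbone,
    heads=heads,
    mask_percentage=0.5,  # keep the top-scoring 50% of parameters for each task
    parameter_score_init_mode="default",
)
```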
`weight_masks`

The binary weight mask of each previous task, percentile-gated from the weight score. Keys are task IDs and values are the corresponding masks. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, input features) as the weight.

`bias_masks`

The binary bias mask of each previous task, percentile-gated from the bias score. Keys are task IDs and values are the corresponding masks. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.

`cumulative_weight_mask_for_previous_tasks`

The cumulative binary weight mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the weight score. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has the same size (output features, input features) as the weight.

`cumulative_bias_mask_for_previous_tasks`

The cumulative binary bias mask $\mathbf{M}_{t-1}$ of previous tasks $1, \cdots, t-1$, percentile-gated from the bias score. It is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
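The cumulative masks are maintained as the element-wise union (maximum) of the per-task binary masks, which the following toy example illustrates (shapes are arbitrary):

```python
import torch

task1_weight_mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
task2_weight_mask = torch.tensor([[1.0, 1.0], [0.0, 0.0]])

# Union of the binary masks: a parameter stays frozen once any previous task has claimed it.
cumulative_mask = torch.max(task1_weight_mask, task2_weight_mask)
print(cumulative_mask)  # tensor([[1., 1.], [0., 1.]])
```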
```python
@property
def automatic_optimization(self) -> bool:
    """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
    return self._automatic_optimization
```
If set to `False` you are responsible for calling `.backward()`, `.step()`, `.zero_grad()`.
```python
def sanity_check(self) -> None:
    r"""Sanity check."""

    # check the backbone and heads
    if not isinstance(self.backbone, WSNMaskBackbone):
        raise ValueError("The backbone should be an instance of WSNMaskBackbone.")
    if not isinstance(self.heads, HeadsTIL):
        raise ValueError("The heads should be an instance of `HeadsTIL`.")

    # check the mask percentage
    if not (0 < self.mask_percentage <= 1):
        raise ValueError(
            f"Mask percentage should be in (0, 1], but got {self.mask_percentage}."
        )
```
Sanity check.
```python
def on_train_start(self) -> None:
    r"""Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task."""

    self.backbone.initialize_parameter_score(mode=self.parameter_score_init_mode)

    # initialize the cumulative mask at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
    if self.task_id == 1:
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            self.cumulative_weight_mask_for_previous_tasks[layer_name] = (
                torch.zeros_like(layer.weight).to(self.device)
            )
            if layer.bias is not None:
                self.cumulative_bias_mask_for_previous_tasks[layer_name] = (
                    torch.zeros_like(layer.bias).to(self.device)
                )
            else:
                self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
        # the cumulative mask $\mathbf{M}_{t-1}$ is initialized as a zeros mask ($t = 1$)
```
Initialize the parameter scores before training the next task and initialize the cumulative masks at the beginning of the first task.
```python
def clip_grad_by_mask(self) -> None:
    r"""Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf)."""

    for layer_name in self.backbone.weighted_layer_names:
        layer = self.backbone.get_layer_by_name(layer_name)

        layer.weight.grad.data *= (
            1 - self.cumulative_weight_mask_for_previous_tasks[layer_name]
        )
        if layer.bias is not None:
            layer.bias.grad.data *= (
                1 - self.cumulative_bias_mask_for_previous_tasks[layer_name]
            )
```
Clip the gradients by the cumulative masks. The gradients are multiplied by (1 - cumulative_previous_mask) to keep previously masked parameters fixed. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
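In equation form, this corresponds to Eq. (4) of the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf), restated here in slightly simplified notation (element-wise product $\odot$, learning rate $\eta$, cumulative mask $\mathbf{M}_{t-1}$ of previous tasks):

$$
\boldsymbol{\theta} \;\leftarrow\; \boldsymbol{\theta} \;-\; \eta \left( \frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}} \odot \left( \mathbf{1} - \mathbf{M}_{t-1} \right) \right)
$$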
```python
def forward(
    self,
    input: torch.Tensor,
    stage: str,
    task_id: int | None = None,
) -> Tensor | tuple[Tensor, dict[str, Tensor], dict[str, Tensor], dict[str, Tensor]]:
    r"""The forward pass for data from task `task_id`. Note that this has nothing to do with the `forward()` method in `nn.Module`.

    **Args:**
    - **input** (`Tensor`): the input tensor from data.
    - **stage** (`str`): the stage of the forward pass, should be one of:
        1. 'train': training stage.
        2. 'validation': validation stage.
        3. 'test': testing stage.
    - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).

    **Returns:**
    - **logits** (`Tensor`): the output logits tensor.
    - **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, input features) as the weight.
    - **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
    - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need the hidden features for various purposes.
    """
    feature, weight_mask, bias_mask, activations = self.backbone(
        input,
        stage=stage,
        mask_percentage=self.mask_percentage,
        test_mask=(
            (self.weight_masks[task_id], self.bias_masks[task_id])
            if stage == "test"
            else None
        ),
    )
    logits = self.heads(feature, task_id)

    return (
        logits
        if self.if_forward_func_return_logits_only
        else (logits, weight_mask, bias_mask, activations)
    )
```
The forward pass for data from task `task_id`. Note that this has nothing to do with the `forward()` method in `nn.Module`.
**Args:**
- **input** (`Tensor`): the input tensor from data.
- **stage** (`str`): the stage of the forward pass, should be one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be from any seen task (TIL uses the provided task IDs for testing).

**Returns:**
- **logits** (`Tensor`): the output logits tensor.
- **weight_mask** (`dict[str, Tensor]`): the weight mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, input features) as the weight.
- **bias_mask** (`dict[str, Tensor]`): the bias mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has the same size (output features, ) as the bias. If the layer doesn't have a bias, it is `None`.
- **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for continual learning algorithms that need the hidden features for various purposes.
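A hypothetical call sketch (variable names and the task ID are illustrative; `algorithm` stands for a `WSN` instance and `x` for a batch of inputs): during training and validation the masks are derived on the fly from the learnable scores, while during testing the masks stored for the requested task are applied.

```python
# Illustrative only: mirrors how the training/validation/test steps below call forward().
logits, weight_mask, bias_mask, activations = algorithm.forward(
    x, stage="train", task_id=algorithm.task_id
)

# At test time, the stored masks of the requested (already learned) task are used instead.
test_logits, _, _, _ = algorithm.forward(x, stage="test", task_id=1)
```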
```python
def training_step(self, batch: Any) -> dict[str, Tensor]:
    r"""Training step for the current task `self.task_id`.

    **Args:**
    - **batch** (`Any`): a batch of training data.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
    """
    x, y = batch

    # zero the gradients before forward pass in manual optimization mode
    opt = self.optimizers()
    opt.zero_grad()

    # classification loss
    logits, weight_mask, bias_mask, activations = self.forward(
        x, stage="train", task_id=self.task_id
    )
    loss_cls = self.criterion(logits, y)

    # total loss
    loss = loss_cls

    # backward step (manually)
    self.manual_backward(loss)  # calculate the gradients
    # WSN hard-clips gradients using the cumulative masks. See Eq. (4) in the [WSN paper](https://proceedings.mlr.press/v162/kang22b/kang22b.pdf).
    self.clip_grad_by_mask()

    # update parameters with the modified gradients
    opt.step()

    # accuracy of the batch
    acc = (logits.argmax(dim=1) == y).float().mean()

    return {
        "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
        "loss_cls": loss_cls,
        "acc": acc,
        "activations": activations,
        "weight_mask": weight_mask,  # return other metrics for Lightning loggers callback to handle at `on_train_batch_end()`
        "bias_mask": bias_mask,
    }
```
Training step for the current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of training data.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. For WSN, it includes 'weight_mask' and 'bias_mask' for logging.
```python
def on_train_end(self) -> None:
    r"""Store the weight and bias masks and update the cumulative masks after training the task."""

    # get the weight and bias mask for the current task
    weight_mask_t = {}
    bias_mask_t = {}
    for layer_name in self.backbone.weighted_layer_names:
        layer = self.backbone.get_layer_by_name(layer_name)

        weight_mask_t[layer_name] = self.backbone.gate_fn.apply(
            self.backbone.weight_score_t[layer_name].weight, self.mask_percentage
        )
        if layer.bias is not None:
            bias_mask_t[layer_name] = self.backbone.gate_fn.apply(
                self.backbone.bias_score_t[layer_name].weight.squeeze(
                    0
                ),  # from (1, output_dim) to (output_dim, )
                self.mask_percentage,
            )
        else:
            bias_mask_t[layer_name] = None

    # store the weight and bias mask for the current task
    self.weight_masks[self.task_id] = weight_mask_t
    self.bias_masks[self.task_id] = bias_mask_t

    # update the cumulative mask
    for layer_name in self.backbone.weighted_layer_names:
        layer = self.backbone.get_layer_by_name(layer_name)

        self.cumulative_weight_mask_for_previous_tasks[layer_name] = torch.max(
            self.cumulative_weight_mask_for_previous_tasks[layer_name],
            weight_mask_t[layer_name],
        )
        if layer.bias is not None:
            self.cumulative_bias_mask_for_previous_tasks[layer_name] = torch.max(
                self.cumulative_bias_mask_for_previous_tasks[layer_name],
                bias_mask_t[layer_name],
            )
        else:
            self.cumulative_bias_mask_for_previous_tasks[layer_name] = None
```
Store the weight and bias masks and update the cumulative masks after training the task.
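The percentile gate `gate_fn` used here belongs to `WSNMaskBackbone` and is not shown in this module. For intuition only, a common way to make a hard top-$c\%$ selection trainable is a straight-through estimator; the sketch below is an assumption about its general shape, not the actual `clarena` code.

```python
import torch


class TopCPercentGate(torch.autograd.Function):
    """Hypothetical straight-through top-c% gate with the same call pattern as
    `gate_fn.apply(score, mask_percentage)`. The real `gate_fn` is defined in
    `WSNMaskBackbone` (not part of this module) and may differ in detail."""

    @staticmethod
    def forward(ctx, score: torch.Tensor, c: float) -> torch.Tensor:
        # Hard selection in the forward pass: keep the top `c` fraction of scores.
        threshold = torch.quantile(score.flatten(), 1.0 - c)
        return (score >= threshold).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Straight-through estimator: pass the gradient to the scores unchanged;
        # the percentage argument gets no gradient.
        return grad_output, None


mask = TopCPercentGate.apply(torch.randn(16, 32), 0.5)  # binary mask with roughly 50% ones
```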
```python
def validation_step(self, batch: Any) -> dict[str, Tensor]:
    r"""Validation step for the current task `self.task_id`.

    **Args:**
    - **batch** (`Any`): a batch of validation data.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this validation step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
    """
    x, y = batch
    logits, _, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
    loss_cls = self.criterion(logits, y)
    acc = (logits.argmax(dim=1) == y).float().mean()

    return {
        "loss_cls": loss_cls,
        "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_validation_batch_end()`
    }
```
Validation step for the current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of validation data.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this validation step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
```python
def test_step(
    self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> dict[str, Tensor]:
    r"""Test step for the current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

    **Args:**
    - **batch** (`Any`): a batch of test data.
    - **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given, otherwise the `LightningModule` will raise a `RuntimeError`.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this test step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
    """
    test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

    x, y = batch
    logits, _, _, _ = self.forward(
        x,
        stage="test",
        task_id=test_task_id,
    )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
    loss_cls = self.criterion(logits, y)
    acc = (logits.argmax(dim=1) == y).float().mean()

    return {
        "loss_cls": loss_cls,
        "acc": acc,  # return metrics for Lightning loggers callback to handle at `on_test_batch_end()`
    }
```
Test step for the current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

**Args:**
- **batch** (`Any`): a batch of test data.
- **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given, otherwise the `LightningModule` will raise a `RuntimeError`.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this test step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.