clarena.cl_algorithms.ewc
The submodule in `cl_algorithms` for the EWC (Elastic Weight Consolidation) algorithm.
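For context, the objective this module implements is equation (3) of the EWC paper: the new-task loss plus a quadratic penalty weighted by the diagonal Fisher information, anchored at the parameters learned on a previous task:

```latex
\mathcal{L}(\theta) \;=\; \mathcal{L}_B(\theta) \;+\; \sum_{i} \frac{\lambda}{2}\, F_i \,\bigl(\theta_i - \theta^{*}_{A,i}\bigr)^2
```

Here $\mathcal{L}_B$ is the current-task loss, $F_i$ the diagonal Fisher entry for parameter $\theta_i$, and $\theta^{*}_{A,i}$ the parameter value after the previous task. In the code below, $\lambda$ corresponds to `parameter_change_reg_factor` and the sum is taken over every stored previous task.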
1r""" 2The submodule in `cl_algorithms` for [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114). 3""" 4 5__all__ = ["EWC", "AmnesiacEWC"] 6 7import logging 8from copy import deepcopy 9from typing import Any 10 11import torch 12from torch import Tensor, nn 13 14from clarena.backbones import CLBackbone 15from clarena.cl_algorithms import AmnesiacCLAlgorithm, Finetuning 16from clarena.cl_algorithms.regularizers import ParameterChangeReg 17from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL 18 19# always get logger for built-in logging in each module 20pylogger = logging.getLogger(__name__) 21 22 23class EWC(Finetuning): 24 r"""[EWC (Elastic Weight Consolidation)](https://www.pnas.org/doi/10.1073/pnas.1611835114) algorithm. 25 26 A regularization-based approach that calculates the fisher information as parameter importance for the previous tasks and penalizes the current task loss with the importance of the parameters. 27 28 We implement EWC as a subclass of Finetuning algorithm, as EWC has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class. 29 """ 30 31 def __init__( 32 self, 33 backbone: CLBackbone, 34 heads: HeadsTIL | HeadsCIL | HeadDIL, 35 parameter_change_reg_factor: float, 36 when_calculate_fisher_information: str, 37 non_algorithmic_hparams: dict[str, Any] = {}, 38 **kwargs, 39 ) -> None: 40 r"""Initialize the EWC algorithm with the network. 41 42 **Args:** 43 - **backbone** (`CLBackbone`): backbone network. 44 - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads. 45 - **parameter_change_reg_factor** (`float`): the parameter change regularization factor. It controls the strength of preventing forgetting. 46 - **when_calculate_fisher_information** (`str`): when to calculate the fisher information. It should be one of the following: 47 1. 'train_end': calculate the fisher information at the end of training of the task. 48 2. 
'train': accumulate the fisher information in the training step of the task. 49 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 50 - **kwargs**: Reserved for multiple inheritance. 51 52 """ 53 super().__init__( 54 backbone=backbone, 55 heads=heads, 56 non_algorithmic_hparams=non_algorithmic_hparams, 57 **kwargs, 58 ) 59 60 # save additional algorithmic hyperparameters 61 self.save_hyperparameters( 62 "parameter_change_reg_factor", 63 "when_calculate_fisher_information", 64 ) 65 66 self.parameter_importance: dict[int, dict[str, Tensor]] = {} 67 r"""The parameter importance of each previous task. Keys are task IDs and values are the corresponding importance. Each importance entity is a dict where keys are parameter names (named by `named_parameters()` of the `nn.Module`) and values are the importance tensor for the layer. It has the same shape as the parameters of the layer. 68 """ 69 70 self.previous_task_backbones: dict[int, nn.Module] = {} 71 r"""The backbone models of the previous tasks. Keys are task IDs and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained. 72 73 Some would argue that since we could store the model of the previous tasks, why don't we test the task directly with the stored model, instead of doing the less easier EWC thing? The thing is, EWC only uses the model of the previous tasks to train current and future tasks, which aggregate them into a single model. Once the training of the task is done, the storage for those parameters can be released. 
However, this make the future tasks not able to use EWC anymore, which is a disadvantage for EWC. 74 """ 75 self.parameter_importance_heads: dict[int, dict[str, Tensor]] = {} 76 r"""The head parameter importance of each previous task (DIL only).""" 77 self.previous_task_heads: dict[int, nn.Module] = {} 78 r"""The head models of the previous tasks (DIL only).""" 79 80 self.parameter_change_reg_factor = parameter_change_reg_factor 81 r"""The parameter change regularization factor.""" 82 self.parameter_change_reg = ParameterChangeReg( 83 factor=parameter_change_reg_factor, 84 ) 85 r"""Initialize and store the parameter change regularizer.""" 86 87 self.when_calculate_fisher_information: str = when_calculate_fisher_information 88 r"""When to calculate the fisher information.""" 89 self.num_data: int 90 r"""The number of data used to calculate the fisher information. It is used to average the fisher information over the data.""" 91 92 # set manual optimization because we need to access gradients to calculate the fisher information in the training step 93 self.automatic_optimization = False 94 95 EWC.sanity_check(self) 96 97 def sanity_check(self) -> None: 98 r"""Sanity check.""" 99 if self.parameter_change_reg_factor <= 0: 100 raise ValueError( 101 f"The parameter change regularization factor should be positive, but got {self.parameter_change_reg_factor}." 
102 ) 103 104 def on_train_start(self) -> None: 105 r"""Initialize the parameter importance and num of data counter.""" 106 107 self.parameter_importance[self.task_id] = {} 108 for param_name, param in self.backbone.named_parameters(): 109 self.parameter_importance[self.task_id][param_name] = 0 * param.data 110 if isinstance(self.heads, HeadDIL): 111 self.parameter_importance_heads[self.task_id] = {} 112 for param_name, param in self.heads.named_parameters(): 113 self.parameter_importance_heads[self.task_id][param_name] = ( 114 0 * param.data 115 ) 116 self.num_data = 0 117 118 def training_step(self, batch: Any) -> dict[str, Tensor]: 119 r"""Training step for current task `self.task_id`.""" 120 x, y = batch 121 122 opt = self.optimizers() 123 opt.zero_grad() 124 125 # classification loss 126 logits, activations = self.forward(x, stage="train", task_id=self.task_id) 127 loss_cls = self.criterion(logits, y) 128 129 batch_size = len(y) 130 self.num_data += batch_size 131 132 # accumulate fisher information during training step if specified 133 if self.when_calculate_fisher_information == "train": 134 # Use autograd.grad to get explicit gradients for Fisher accumulation without 135 # relying on global .backward() state in manual optimization. 
136 backbone_params: list[tuple[str, Tensor]] = [] 137 for param_name, param in self.backbone.named_parameters(): 138 if not param.requires_grad: 139 continue 140 backbone_params.append((param_name, param)) 141 if isinstance(self.heads, HeadDIL): 142 head_params: list[tuple[str, Tensor]] = [] 143 for param_name, param in self.heads.named_parameters(): 144 if not param.requires_grad: 145 continue 146 head_params.append((param_name, param)) 147 else: 148 head_params = [] 149 150 if backbone_params: 151 grads = torch.autograd.grad( 152 loss_cls, 153 [param for _, param in backbone_params], 154 retain_graph=True, 155 allow_unused=True, 156 ) 157 for (param_name, _), grad in zip(backbone_params, grads): 158 if grad is None: 159 continue 160 self.parameter_importance[self.task_id][param_name] += ( 161 batch_size * grad.detach() ** 2 162 ) 163 164 if head_params: 165 grads = torch.autograd.grad( 166 loss_cls, 167 [param for _, param in head_params], 168 retain_graph=True, 169 allow_unused=True, 170 ) 171 for (param_name, _), grad in zip(head_params, grads): 172 if grad is None: 173 continue 174 self.parameter_importance_heads[self.task_id][param_name] += ( 175 batch_size * grad.detach() ** 2 176 ) 177 178 # regularization loss. 
See equation (3) in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114) 179 ewc_reg = 0.0 180 for previous_task_id, previous_backbone in self.previous_task_backbones.items(): 181 ewc_reg += 0.5 * self.parameter_change_reg( 182 target_model=self.backbone, 183 ref_model=previous_backbone, 184 weights=self.parameter_importance[previous_task_id], 185 ) 186 if isinstance(self.heads, HeadDIL): 187 ewc_reg += 0.5 * self.parameter_change_reg( 188 target_model=self.heads, 189 ref_model=self.previous_task_heads[previous_task_id], 190 weights=self.parameter_importance_heads[previous_task_id], 191 ) 192 193 # total loss 194 loss = loss_cls + ewc_reg 195 196 self.manual_backward(loss) 197 opt.step() 198 199 # predicted labels 200 preds = logits.argmax(dim=1) 201 202 # accuracy of the batch 203 acc = (preds == y).float().mean() 204 205 return { 206 "preds": preds, 207 "loss": loss, # return loss is essential for training step, or backpropagation will fail 208 "loss_cls": loss_cls, 209 "ewc_reg": ewc_reg, 210 "acc": acc, 211 "activations": activations, 212 } 213 214 def on_train_end(self) -> None: 215 r"""Calculate the fisher information as parameter importance and store the backbone model after the training of a task.""" 216 217 # calculate fisher information at the end of training if specified 218 if self.when_calculate_fisher_information == "train_end": 219 fisher, fisher_heads, fisher_num_data = ( 220 self.accumulate_fisher_information_on_train_end() 221 ) 222 self.parameter_importance[self.task_id] = fisher 223 if fisher_heads is not None: 224 self.parameter_importance_heads[self.task_id] = fisher_heads 225 num_data = fisher_num_data 226 else: 227 num_data = self.num_data 228 229 # no matter when we calculate the fisher information, we need to average it over the number of data 230 for param_name, param in self.backbone.named_parameters(): 231 self.parameter_importance[self.task_id][param_name] /= num_data 232 if isinstance(self.heads, HeadDIL): 233 for 
param_name, param in self.heads.named_parameters(): 234 self.parameter_importance_heads[self.task_id][param_name] /= num_data 235 236 # store the backbone model after training the task 237 previous_backbone = deepcopy(self.backbone) 238 previous_backbone.eval() 239 self.previous_task_backbones[self.task_id] = previous_backbone 240 if isinstance(self.heads, HeadDIL): 241 previous_heads = deepcopy(self.heads) 242 previous_heads.eval() 243 self.previous_task_heads[self.task_id] = previous_heads 244 245 def accumulate_fisher_information_on_train_end( 246 self, 247 ) -> tuple[dict[str, Tensor], dict[str, Tensor] | None, int]: 248 r"""Accumulate the fisher information as the parameter importance for the learned task `self.task_id` at the end of its training.""" 249 fisher_information_t = {} 250 fisher_information_heads: dict[str, Tensor] | None = None 251 num_data = 0 252 253 self.eval() 254 last_task_train_dataloaders = self.trainer.datamodule.train_dataloader() 255 256 for param_name, param in self.backbone.named_parameters(): 257 fisher_information_t[param_name] = torch.zeros_like(param) 258 if isinstance(self.heads, HeadDIL): 259 fisher_information_heads = {} 260 for param_name, param in self.heads.named_parameters(): 261 fisher_information_heads[param_name] = torch.zeros_like(param) 262 263 for x, y in last_task_train_dataloaders: 264 x = x.to(self.device) 265 y = y.to(self.device) 266 batch_size = len(y) 267 num_data += batch_size 268 269 self.backbone.zero_grad() 270 if isinstance(self.heads, HeadDIL): 271 self.heads.zero_grad() 272 logits, _ = self.forward(x, stage="train", task_id=self.task_id) 273 loss_cls = self.criterion(logits, y) 274 loss_cls.backward() 275 276 for param_name, param in self.backbone.named_parameters(): 277 fisher_information_t[param_name] += batch_size * param.grad**2 278 if fisher_information_heads is not None: 279 for param_name, param in self.heads.named_parameters(): 280 if param.grad is None: 281 continue 282 
fisher_information_heads[param_name] += batch_size * param.grad**2 283 284 return fisher_information_t, fisher_information_heads, num_data 285 286 287class AmnesiacEWC(AmnesiacCLAlgorithm, EWC): 288 r"""Amnesiac EWC algorithm.""" 289 290 def __init__( 291 self, 292 backbone: CLBackbone, 293 heads: HeadsTIL | HeadsCIL | HeadDIL, 294 non_algorithmic_hparams: dict[str, Any] = {}, 295 disable_unlearning: bool = False, 296 **kwargs, 297 ) -> None: 298 r"""Initialize the Amnesiac EWC algorithm with the network. 299 300 **Args:** 301 - **backbone** (`CLBackbone`): backbone network. 302 - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads. 303 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 304 - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`. 305 - **kwargs**: Reserved for multiple inheritance. 306 """ 307 super().__init__( 308 backbone=backbone, 309 heads=heads, 310 non_algorithmic_hparams=non_algorithmic_hparams, 311 disable_unlearning=disable_unlearning, 312 **kwargs, 313 ) 314 315 def on_train_start(self) -> None: 316 """Record backbone parameters before training current task.""" 317 EWC.on_train_start(self) 318 AmnesiacCLAlgorithm.on_train_start(self) 319 320 def on_train_end(self) -> None: 321 """Record backbone parameters before training current task.""" 322 EWC.on_train_end(self) 323 AmnesiacCLAlgorithm.on_train_end(self)
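The core mechanics above (diagonal Fisher estimation, then a Fisher-weighted quadratic penalty) can be sketched standalone. This is not the clarena API: the tiny model, data, and the `ewc_penalty` helper are illustrative stand-ins, and the penalty plays the role of `ParameterChangeReg`.

```python
# Minimal sketch of diagonal-Fisher EWC, mirroring the logic of
# `accumulate_fisher_information_on_train_end` and `training_step` above.
# All names here are illustrative, not part of clarena.
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(4, 3)
criterion = nn.CrossEntropyLoss()

# toy "previous task" data
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

# 1) diagonal Fisher: batch-size-weighted sum of squared gradients,
#    then averaged over the total number of data points
fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
num_data = 0
for xb, yb in [(x[:4], y[:4]), (x[4:], y[4:])]:
    model.zero_grad()
    loss = criterion(model(xb), yb)
    loss.backward()
    num_data += len(yb)
    for name, p in model.named_parameters():
        if p.grad is not None:
            fisher[name] += len(yb) * p.grad.detach() ** 2
for name in fisher:
    fisher[name] /= num_data

# 2) anchor the penalty at the post-task weights theta*
anchor = {name: p.detach().clone() for name, p in model.named_parameters()}

def ewc_penalty(model, anchor, fisher, reg_factor=1.0):
    """0.5 * lambda * sum_i F_i * (theta_i - theta*_i)^2."""
    reg = 0.0
    for name, p in model.named_parameters():
        reg = reg + 0.5 * reg_factor * (fisher[name] * (p - anchor[name]) ** 2).sum()
    return reg

# at the anchor itself the penalty is exactly zero
print(float(ewc_penalty(model, anchor, fisher)))  # -> 0.0
```

As the weights drift from the anchor while training a new task, the penalty grows fastest along directions the Fisher marks as important, which is what slows forgetting.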
98 def sanity_check(self) -> None: 99 r"""Sanity check.""" 100 if self.parameter_change_reg_factor <= 0: 101 raise ValueError( 102 f"The parameter change regularization factor should be positive, but got {self.parameter_change_reg_factor}." 103 )
Sanity check.
105 def on_train_start(self) -> None: 106 r"""Initialize the parameter importance and num of data counter.""" 107 108 self.parameter_importance[self.task_id] = {} 109 for param_name, param in self.backbone.named_parameters(): 110 self.parameter_importance[self.task_id][param_name] = 0 * param.data 111 if isinstance(self.heads, HeadDIL): 112 self.parameter_importance_heads[self.task_id] = {} 113 for param_name, param in self.heads.named_parameters(): 114 self.parameter_importance_heads[self.task_id][param_name] = ( 115 0 * param.data 116 ) 117 self.num_data = 0
def training_step(self, batch: Any) -> dict[str, Tensor]:
    r"""Training step for current task `self.task_id`."""
    x, y = batch

    opt = self.optimizers()
    opt.zero_grad()

    # classification loss
    logits, activations = self.forward(x, stage="train", task_id=self.task_id)
    loss_cls = self.criterion(logits, y)

    batch_size = len(y)
    self.num_data += batch_size

    # accumulate fisher information during the training step if specified
    if self.when_calculate_fisher_information == "train":
        # use autograd.grad to get explicit gradients for the Fisher accumulation,
        # without relying on global .backward() state in manual optimization
        backbone_params: list[tuple[str, Tensor]] = [
            (param_name, param)
            for param_name, param in self.backbone.named_parameters()
            if param.requires_grad
        ]
        if isinstance(self.heads, HeadDIL):
            head_params: list[tuple[str, Tensor]] = [
                (param_name, param)
                for param_name, param in self.heads.named_parameters()
                if param.requires_grad
            ]
        else:
            head_params = []

        if backbone_params:
            grads = torch.autograd.grad(
                loss_cls,
                [param for _, param in backbone_params],
                retain_graph=True,
                allow_unused=True,
            )
            for (param_name, _), grad in zip(backbone_params, grads):
                if grad is None:
                    continue
                self.parameter_importance[self.task_id][param_name] += (
                    batch_size * grad.detach() ** 2
                )

        if head_params:
            grads = torch.autograd.grad(
                loss_cls,
                [param for _, param in head_params],
                retain_graph=True,
                allow_unused=True,
            )
            for (param_name, _), grad in zip(head_params, grads):
                if grad is None:
                    continue
                self.parameter_importance_heads[self.task_id][param_name] += (
                    batch_size * grad.detach() ** 2
                )

    # regularization loss; see equation (3) in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114)
    ewc_reg = 0.0
    for previous_task_id, previous_backbone in self.previous_task_backbones.items():
        ewc_reg += 0.5 * self.parameter_change_reg(
            target_model=self.backbone,
            ref_model=previous_backbone,
            weights=self.parameter_importance[previous_task_id],
        )
        if isinstance(self.heads, HeadDIL):
            ewc_reg += 0.5 * self.parameter_change_reg(
                target_model=self.heads,
                ref_model=self.previous_task_heads[previous_task_id],
                weights=self.parameter_importance_heads[previous_task_id],
            )

    # total loss
    loss = loss_cls + ewc_reg

    self.manual_backward(loss)
    opt.step()

    # predicted labels
    preds = logits.argmax(dim=1)

    # accuracy of the batch
    acc = (preds == y).float().mean()

    return {
        "preds": preds,
        "loss": loss,  # returning the loss is essential for the training step
        "loss_cls": loss_cls,
        "ewc_reg": ewc_reg,
        "acc": acc,
        "activations": activations,
    }
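The regularization term follows equation (3) of the EWC paper: the total loss is the classification loss plus, for each previous task, a quadratic penalty (λ/2) · Σᵢ Fᵢ (θᵢ − θ*ᵢ)² weighted by the Fisher-based importance. A torch-free scalar sketch of this penalty (the function name and flat parameter dicts are illustrative, not the `ParameterChangeReg` API):

```python
def ewc_penalty(
    params: dict[str, float],
    ref_params: dict[str, float],
    importance: dict[str, float],
    reg_factor: float,
) -> float:
    """EWC quadratic penalty: 0.5 * lambda * sum_i F_i * (theta_i - theta*_i)^2."""
    return 0.5 * reg_factor * sum(
        importance[name] * (params[name] - ref_params[name]) ** 2
        for name in params
    )

# a parameter that drifted from 0.0 to 2.0 with importance 1.0 and lambda = 1.0
penalty = ewc_penalty({"w": 2.0}, {"w": 0.0}, {"w": 1.0}, reg_factor=1.0)
```

Parameters with zero importance contribute nothing, so the penalty only anchors weights the previous tasks actually relied on.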
def on_train_end(self) -> None:
    r"""Calculate the fisher information as parameter importance and store the backbone model after the training of a task."""

    # calculate the fisher information at the end of training if specified
    if self.when_calculate_fisher_information == "train_end":
        fisher, fisher_heads, fisher_num_data = (
            self.accumulate_fisher_information_on_train_end()
        )
        self.parameter_importance[self.task_id] = fisher
        if fisher_heads is not None:
            self.parameter_importance_heads[self.task_id] = fisher_heads
        num_data = fisher_num_data
    else:
        num_data = self.num_data

    # regardless of when the fisher information was calculated, average it over the number of data
    for param_name, _ in self.backbone.named_parameters():
        self.parameter_importance[self.task_id][param_name] /= num_data
    if isinstance(self.heads, HeadDIL):
        for param_name, _ in self.heads.named_parameters():
            self.parameter_importance_heads[self.task_id][param_name] /= num_data

    # store the backbone model after training the task
    previous_backbone = deepcopy(self.backbone)
    previous_backbone.eval()
    self.previous_task_backbones[self.task_id] = previous_backbone
    if isinstance(self.heads, HeadDIL):
        previous_heads = deepcopy(self.heads)
        previous_heads.eval()
        self.previous_task_heads[self.task_id] = previous_heads
def accumulate_fisher_information_on_train_end(
    self,
) -> tuple[dict[str, Tensor], dict[str, Tensor] | None, int]:
    r"""Accumulate the fisher information as the parameter importance for the learned task `self.task_id` at the end of its training."""
    fisher_information_t = {}
    fisher_information_heads: dict[str, Tensor] | None = None
    num_data = 0

    self.eval()
    last_task_train_dataloaders = self.trainer.datamodule.train_dataloader()

    for param_name, param in self.backbone.named_parameters():
        fisher_information_t[param_name] = torch.zeros_like(param)
    if isinstance(self.heads, HeadDIL):
        fisher_information_heads = {}
        for param_name, param in self.heads.named_parameters():
            fisher_information_heads[param_name] = torch.zeros_like(param)

    for x, y in last_task_train_dataloaders:
        x = x.to(self.device)
        y = y.to(self.device)
        batch_size = len(y)
        num_data += batch_size

        self.backbone.zero_grad()
        if isinstance(self.heads, HeadDIL):
            self.heads.zero_grad()
        logits, _ = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        loss_cls.backward()

        for param_name, param in self.backbone.named_parameters():
            if param.grad is None:  # skip parameters not involved in this loss
                continue
            fisher_information_t[param_name] += batch_size * param.grad**2
        if fisher_information_heads is not None:
            for param_name, param in self.heads.named_parameters():
                if param.grad is None:
                    continue
                fisher_information_heads[param_name] += batch_size * param.grad**2

    return fisher_information_t, fisher_information_heads, num_data
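This method estimates the diagonal empirical Fisher as the average squared gradient of the loss over the training data. A torch-free sketch for a single scalar parameter, using an analytic gradient in place of autograd (all names and the squared-error loss are illustrative):

```python
def empirical_fisher(grad_fn, data, w: float) -> float:
    """Diagonal empirical Fisher for a scalar parameter w:
    the average of squared per-sample loss gradients over the data."""
    total, n = 0.0, 0
    for x, y in data:
        g = grad_fn(w, x, y)  # per-sample gradient dL/dw
        total += g ** 2
        n += 1
    return total / n

def grad_sq(w: float, x: float, y: float) -> float:
    """Analytic gradient of the squared-error loss L = (w*x - y)^2."""
    return 2.0 * x * (w * x - y)

fisher = empirical_fisher(grad_sq, data=[(1.0, 1.0), (2.0, 3.0)], w=1.0)
```

The real implementation accumulates per-parameter tensors weighted by batch size and divides by the total sample count afterwards, but the averaging principle is the same.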
class AmnesiacEWC(AmnesiacCLAlgorithm, EWC):
    r"""Amnesiac EWC algorithm."""

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] = {},
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""Initialize the Amnesiac EWC algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself, passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
        - **kwargs**: reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            disable_unlearning=disable_unlearning,
            **kwargs,
        )

    def on_train_start(self) -> None:
        """Record backbone parameters before training the current task."""
        EWC.on_train_start(self)
        AmnesiacCLAlgorithm.on_train_start(self)

    def on_train_end(self) -> None:
        """Record backbone parameters after training the current task."""
        EWC.on_train_end(self)
        AmnesiacCLAlgorithm.on_train_end(self)
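`AmnesiacEWC` relies on cooperative multiple inheritance: each `__init__` consumes its own keyword arguments and forwards the rest through `super().__init__(**kwargs)` along the MRO. A minimal torch-free sketch of this pattern (class names here are illustrative, not the clarena classes):

```python
class Base:
    def __init__(self, **kwargs) -> None:
        self.base_seen = True  # end of the cooperative chain

class MixinA(Base):
    def __init__(self, a=None, **kwargs) -> None:
        self.a = a  # consume this class's own argument
        super().__init__(**kwargs)  # forward the rest along the MRO

class MixinB(Base):
    def __init__(self, b=None, **kwargs) -> None:
        self.b = b
        super().__init__(**kwargs)

class Combined(MixinA, MixinB):
    """MRO: Combined -> MixinA -> MixinB -> Base, so every __init__ runs once."""
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

obj = Combined(a=1, b=2)
```

Calling the parent hooks explicitly (as `AmnesiacEWC.on_train_start` does with `EWC.on_train_start(self)` and `AmnesiacCLAlgorithm.on_train_start(self)`) is an alternative to `super()` when both parents must run and the order matters.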