clarena.cl_algorithms.ewc
The submodule in `cl_algorithms` for the [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114).
1r""" 2The submodule in `cl_algorithms` for [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114). 3""" 4 5__all__ = ["EWC"] 6 7import logging 8from copy import deepcopy 9from typing import Any 10 11import torch 12from torch import Tensor, nn 13 14from clarena.backbones import CLBackbone 15from clarena.cl_algorithms import Finetuning 16from clarena.cl_algorithms.regularisers import ParameterChangeReg 17from clarena.cl_heads import HeadsCIL, HeadsTIL 18 19# always get logger for built-in logging in each module 20pylogger = logging.getLogger(__name__) 21 22 23class EWC(Finetuning): 24 r"""EWC (Elastic Weight Consolidation) algorithm. 25 26 [EWC (Elastic Weight Consolidation, 2017)](https://www.pnas.org/doi/10.1073/pnas.1611835114) is a regularisation-based continual learning approach that calculates parameter importance for the previous tasks and penalises the current task loss with the importance of the parameters. 27 28 We implement EWC as a subclass of Finetuning algorithm, as EWC has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class. 29 """ 30 31 def __init__( 32 self, 33 backbone: CLBackbone, 34 heads: HeadsTIL | HeadsCIL, 35 parameter_change_reg_factor: float, 36 parameter_change_reg_p_norm: float, 37 ) -> None: 38 r"""Initialise the HAT algorithm with the network. 39 40 **Args:** 41 - **backbone** (`CLBackbone`): backbone network. 42 - **heads** (`HeadsTIL` | `HeadsCIL`): output heads. 43 - **parameter_change_reg_factor** (`float`): the parameter change regularisation factor. It controls the strength of preventing forgetting. 44 - **parameter_change_reg_p_norm** (`float`): the norm of the distance of parameters between previous tasks and current task in the parameter change regularisation. 45 46 """ 47 Finetuning.__init__(self, backbone=backbone, heads=heads) 48 49 self.parameter_importance: dict[str, dict[str, Tensor]] = {} 50 r"""Store the parameter importance of each previous task. Keys are task IDs (string type) and values are the corresponding importance. Each importance entity is a dict where keys are parameter names (named by `named_parameters()` of the `nn.Module`) and values are the importance tensor for the layer. It has the same shape as the parameters of the layer. 51 """ 52 self.previous_task_backbones: dict[str, nn.Module] = {} 53 r"""Store the backbone models of the previous tasks. Keys are task IDs (string type) and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained. 54 55 Some would argue that since we could store the model of the previous tasks, why don't we test the task directly with the stored model, instead of doing the less easier EWC thing? The thing is, EWC only uses the model of the previous tasks to train current and future tasks, which aggregate them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this make the future tasks not able to use EWC anymore, which is a disadvantage for EWC. 
56 """ 57 58 self.parameter_change_reg_factor = parameter_change_reg_factor 59 r"""Store parameter change regularisation factor.""" 60 self.parameter_change_reg_p_norm = parameter_change_reg_p_norm 61 r"""Store norm of the distance used in parameter change regularisation.""" 62 self.parameter_change_reg = ParameterChangeReg( 63 factor=parameter_change_reg_factor, 64 p_norm=parameter_change_reg_p_norm, 65 ) 66 r"""Initialise and store the parameter change regulariser.""" 67 68 EWC.sanity_check(self) 69 70 def sanity_check(self) -> None: 71 r"""Check the sanity of the arguments. 72 73 **Raises:** 74 - **ValueError**: If the regularisation factor is not positive. 75 """ 76 77 if self.parameter_change_reg_factor <= 0: 78 raise ValueError( 79 f"The parameter change regularisation factor should be positive, but got {self.parameter_change_reg_factor}." 80 ) 81 82 def calculate_parameter_importance(self) -> None: 83 r"""Calculate the parameter importance for the learned task. This is only called after the training of a task, which is the last previous task $t-1$. The calculated importance is stored in `self.parameter_importance[self.task_id]` for constructing the regularisation loss in the future tasks. 84 85 According to [the EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114), the importance tensor is a Laplace approximation to Fisher information matrix by taking the digonal, i.e. $F_i$, where $i$ is the index of a parameter. The calculation is not following that theory but the derived formula below: 86 87 $$\omega_i = F_i =\frac{1}{N_{t-1}} \sum_{(\mathbf{x}, y)\in \mathcal{D}^{(t-1)}_{\text{train}}} \left[\frac{\partial l(f^{(t-1)}\left(\mathbf{x}, \theta), y\right)}{\partial \theta_i}\right]^2$$ 88 89 For a parameter $i$, its importance is the magnitude (square here) of gradient of the loss of model just trained over the training data just used. The $l$ is the classification loss. It shows the sensitivity of the loss to the parameter. The larger it is, the more it changed the performance (which is the loss) of the model, which indicates the importance of the parameter. 
90 """ 91 parameter_importance_t = {} 92 93 # set model to evaluation mode to prevent updating the model parameters 94 self.eval() 95 96 # get the training data 97 last_task_train_dataloaders = self.trainer.datamodule.train_dataloader() 98 99 # initialise the accumulation of the squared gradients 100 num_data = 0 101 for param_name, param in self.backbone.named_parameters(): 102 parameter_importance_t[param_name] = torch.zeros_like(param) 103 104 for x, y in last_task_train_dataloaders: 105 106 # move data to device manually 107 x = x.to(self.device) 108 y = y.to(self.device) 109 110 batch_size = len(y) 111 num_data += batch_size 112 113 # compute the gradients within a batch 114 self.backbone.zero_grad() # reset gradients 115 logits, _ = self.forward(x, stage="train", task_id=self.task_id) 116 loss_cls = self.criterion(logits, y) 117 loss_cls.backward() # compute gradients 118 119 # collect and accumulate the squared gradients into parameter importance 120 for param_name, param in self.backbone.named_parameters(): 121 parameter_importance_t[param_name] += batch_size * param.grad**2 122 123 num_params = sum(p.numel() for p in self.backbone.parameters()) 124 125 for param_name, param in self.backbone.named_parameters(): 126 parameter_importance_t[param_name] /= ( 127 num_data * num_params 128 ) # average over data and parameters 129 130 self.parameter_importance[self.task_id] = parameter_importance_t 131 132 def training_step(self, batch: Any) -> dict[str, Tensor]: 133 r"""Training step for current task `self.task_id`. 134 135 **Args:** 136 - **batch** (`Any`): a batch of training data. 137 138 **Returns:** 139 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Key (`str`) is the metrics name, value (`Tensor`) is the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 140 """ 141 x, y = batch 142 143 # classification loss 144 logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id) 145 loss_cls = self.criterion(logits, y) 146 147 # regularisation loss. See equation (3) in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114). 148 loss_reg = 0.0 149 for previous_task_id in range(1, self.task_id): 150 # sum over all previous models, because [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114) says: "When moving to a third task, task C, EWC will try to keep the network parameters close to the learned parameters of both tasks A and B. This can be enforced either with two separate penalties or as one by noting that the sum of two quadratic penalties is itself a quadratic penalty." 151 loss_reg += self.parameter_change_reg( 152 target_model=self.backbone, 153 ref_model=self.previous_task_backbones[previous_task_id], 154 weights=self.parameter_importance[previous_task_id], 155 ) 156 157 if self.task_id != 1: 158 loss_reg /= ( 159 self.task_id 160 ) # average over tasks to avoid linear increase of the regularisation loss 161 162 # total loss 163 loss = loss_cls + loss_reg 164 165 # accuracy of the batch 166 acc = (logits.argmax(dim=1) == y).float().mean() 167 168 return { 169 "loss": loss, # Return loss is essential for training step, or backpropagation will fail 170 "loss_cls": loss_cls, 171 "loss_reg": loss_reg, 172 "acc": acc, 173 "hidden_features": hidden_features, 174 } 175 176 def on_train_end(self) -> None: 177 r"""Calculate the parameter importance and store the backbone model after the training of a task. 
178 179 The calculated importance and model are stored in `self.parameter_importance[self.task_id]` and `self.previous_task_backbones[self.task_id]` respectively for constructing the regularisation loss in the future tasks. 180 """ 181 self.calculate_parameter_importance() 182 183 previous_backbone = deepcopy(self.backbone) 184 previous_backbone.eval() # set the store model to evaluation mode to prevent updating 185 self.previous_task_backbones[self.task_id] = previous_backbone
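For reference, the total loss assembled in `training_step` combines the classification loss with an importance-weighted penalty on parameter change against each stored previous backbone. Assuming `ParameterChangeReg` computes the weighted $p$-norm penalty its constructor arguments suggest, the objective for task $t$ is:

$$\mathcal{L}^{(t)} = l\left(f^{(t)}(\mathbf{x}, \theta), y\right) + \frac{\lambda}{t} \sum_{t'=1}^{t-1} \sum_i \omega_i^{(t')} \left|\theta_i - \theta_i^{(t')}\right|^p$$

where $\lambda$ is `parameter_change_reg_factor`, $p$ is `parameter_change_reg_p_norm`, $\omega^{(t')}$ is the importance stored for task $t'$, $\theta^{(t')}$ are the backbone parameters stored after task $t'$, and the $1/t$ factor is the averaging over tasks applied in `training_step`. Equation (3) of the paper is the quadratic case $p = 2$.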
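A minimal construction sketch, assuming `backbone` and `heads` objects built elsewhere in `clarena` (their constructors are outside this module, so they appear here as hypothetical stand-ins) and illustrative hyperparameter values:

```python
from clarena.cl_algorithms import EWC

# `backbone` (a CLBackbone) and `heads` (HeadsTIL or HeadsCIL) are hypothetical
# stand-ins assumed to be constructed elsewhere; they are not defined in this module.
ewc = EWC(
    backbone=backbone,
    heads=heads,
    parameter_change_reg_factor=1.0,  # illustrative value; must be positive (see sanity_check())
    parameter_change_reg_p_norm=2.0,  # p = 2 recovers the quadratic penalty of the paper
)
```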
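To make the importance computation concrete, here is a self-contained sketch of the same diagonal-Fisher estimate on a plain PyTorch model. The toy model, random data, and the `diagonal_fisher` helper are illustrative only; the normalisation mirrors `calculate_parameter_importance`, which averages over both data and parameter count.

```python
import torch
from torch import nn


def diagonal_fisher(model: nn.Module, loader, criterion) -> dict[str, torch.Tensor]:
    """Diagonal-Fisher importance: squared loss gradients averaged over the data."""
    importance = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    num_data = 0
    for x, y in loader:
        model.zero_grad()  # reset gradients accumulated from the previous batch
        criterion(model(x), y).backward()  # gradients of the classification loss
        num_data += len(y)
        for name, p in model.named_parameters():
            importance[name] += len(y) * p.grad**2  # weight each batch by its size
    num_params = sum(p.numel() for p in model.parameters())
    for name in importance:
        importance[name] /= num_data * num_params  # average over data and parameters
    return importance


# toy usage on random data (illustrative only)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
x, y = torch.randn(32, 4), torch.randint(0, 3, (32,))
loader = [(x[:16], y[:16]), (x[16:], y[16:])]
fisher = diagonal_fisher(model, loader, nn.CrossEntropyLoss())
assert all(v.shape == p.shape for v, p in zip(fisher.values(), model.parameters()))
```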
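The actual `ParameterChangeReg` lives in `clarena.cl_algorithms.regularisers` and is not shown in this module. The following is a guess at a minimal, functionally similar regulariser, inferred from the call site (`target_model`, `ref_model`, `weights`) and the constructor arguments (`factor`, `p_norm`); treat it as a sketch, not the library's implementation.

```python
import copy

import torch
from torch import nn


class ParameterChangeRegSketch(nn.Module):
    """Importance-weighted p-norm penalty on parameter change (a sketch, not clarena's ParameterChangeReg)."""

    def __init__(self, factor: float, p_norm: float) -> None:
        super().__init__()
        self.factor = factor
        self.p_norm = p_norm

    def forward(
        self,
        target_model: nn.Module,
        ref_model: nn.Module,
        weights: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        ref_params = dict(ref_model.named_parameters())
        reg = torch.zeros(())
        for name, param in target_model.named_parameters():
            # importance-weighted distance between current and stored parameters;
            # detach the reference so no gradient flows into the stored backbone
            diff = (param - ref_params[name].detach()).abs() ** self.p_norm
            reg = reg + (weights[name] * diff).sum()
        return self.factor * reg


# usage sketch: the penalty between a model and an identical frozen copy is zero
model = nn.Linear(4, 2)
ref = copy.deepcopy(model)
weights = {n: torch.ones_like(p) for n, p in model.named_parameters()}
reg = ParameterChangeRegSketch(factor=1.0, p_norm=2.0)
assert reg(model, ref, weights).item() == 0.0
```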
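One note on the snapshot taken in `on_train_end`: `deepcopy` plus `eval()` freezes batch-norm statistics and disables dropout, but the copied parameters still participate in autograd when used in the penalty. Below is a hedged sketch of a variant that additionally freezes the copy; this is an option to consider, not what the class above does (and it is unnecessary if the regulariser already detaches the reference parameters).

```python
from copy import deepcopy

from torch import nn


def snapshot_backbone(backbone: nn.Module) -> nn.Module:
    """Frozen copy of a trained backbone, as stored in `previous_task_backbones`."""
    frozen = deepcopy(backbone)
    frozen.eval()  # freeze batch-norm statistics and disable dropout
    for p in frozen.parameters():
        p.requires_grad_(False)  # optional: exclude the snapshot from autograd
    return frozen
```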