clarena.cl_algorithms.ewc

The submodule in cl_algorithms for the EWC (Elastic Weight Consolidation) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114).
  3"""
  4
  5__all__ = ["EWC"]

import logging
from copy import deepcopy
from typing import Any

import torch
from torch import Tensor, nn

from clarena.backbones import CLBackbone
from clarena.cl_algorithms import Finetuning
from clarena.cl_algorithms.regularisers import ParameterChangeReg
from clarena.cl_heads import HeadsCIL, HeadsTIL

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class EWC(Finetuning):
    r"""EWC (Elastic Weight Consolidation) algorithm.

    [EWC (Elastic Weight Consolidation, 2017)](https://www.pnas.org/doi/10.1073/pnas.1611835114) is a regularisation-based continual learning approach that calculates parameter importance for previous tasks and penalises the current task loss with the importance of the parameters.

    We implement EWC as a subclass of the `Finetuning` algorithm, as EWC has the same `forward()`, `validation_step()` and `test_step()` methods as the `Finetuning` class.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
        parameter_change_reg_factor: float,
        parameter_change_reg_p_norm: float,
    ) -> None:
 38        r"""Initialise the HAT algorithm with the network.
 39
 40        **Args:**
 41        - **backbone** (`CLBackbone`): backbone network.
 42        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
 43        - **parameter_change_reg_factor** (`float`): the parameter change regularisation factor. It controls the strength of preventing forgetting.
 44        - **parameter_change_reg_p_norm** (`float`): the norm of the distance of parameters between previous tasks and current task in the parameter change regularisation.
 45
 46        """
        Finetuning.__init__(self, backbone=backbone, heads=heads)

        self.parameter_importance: dict[int, dict[str, Tensor]] = {}
        r"""Store the parameter importance of each previous task. Keys are task IDs and values are the corresponding importance. Each importance entry is a dict where keys are parameter names (as named by `named_parameters()` of the `nn.Module`) and values are the importance tensors of the layers. Each importance tensor has the same shape as the parameter it belongs to.
        """
        self.previous_task_backbones: dict[int, nn.Module] = {}
        r"""Store the backbone models of the previous tasks. Keys are task IDs and values are the corresponding models. Each model is an `nn.Module` backbone as it stood after training on the corresponding previous task.

        One might ask: if the models of previous tasks are stored anyway, why not test each task directly with its stored model instead of going through EWC? The answer is that EWC uses the previous models only to train the current and future tasks, aggregating all tasks into a single model. Once training is finished, the storage for those parameters can be released. However, releasing them means future tasks can no longer use EWC, which is a disadvantage of EWC.
        """

        self.parameter_change_reg_factor = parameter_change_reg_factor
        r"""Store parameter change regularisation factor."""
        self.parameter_change_reg_p_norm = parameter_change_reg_p_norm
        r"""Store norm of the distance used in parameter change regularisation."""
        self.parameter_change_reg = ParameterChangeReg(
            factor=parameter_change_reg_factor,
            p_norm=parameter_change_reg_p_norm,
        )
        r"""Initialise and store the parameter change regulariser."""

        EWC.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: If the regularisation factor is not positive.
        """

        if self.parameter_change_reg_factor <= 0:
            raise ValueError(
                f"The parameter change regularisation factor should be positive, but got {self.parameter_change_reg_factor}."
            )

    def calculate_parameter_importance(self) -> None:
        r"""Calculate the parameter importance for the task just learned. This is only called after the training of a task, which then becomes the last previous task $t-1$. The calculated importance is stored in `self.parameter_importance[self.task_id]` for constructing the regularisation loss in future tasks.

        According to [the EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114), the importance tensor is a Laplace approximation of the Fisher information matrix restricted to its diagonal, i.e. $F_i$, where $i$ is the index of a parameter. The calculation does not follow that theory directly but uses the formula derived from it:

        $$\omega_i = F_i = \frac{1}{N_{t-1}} \sum_{(\mathbf{x}, y)\in \mathcal{D}^{(t-1)}_{\text{train}}} \left[\frac{\partial l\left(f^{(t-1)}(\mathbf{x}, \theta), y\right)}{\partial \theta_i}\right]^2$$

        For a parameter $i$, its importance is the squared magnitude of the gradient of the loss of the just-trained model over the training data just used, where $l$ is the classification loss. It measures the sensitivity of the loss to the parameter: the larger it is, the more the parameter affects the performance (i.e. the loss) of the model, which indicates the importance of the parameter.
        """
        parameter_importance_t = {}

        # set the model to evaluation mode so that batch normalisation and dropout
        # behave deterministically; no optimiser step is taken in this method
        self.eval()

        # get the training data of the task just learned
        last_task_train_dataloader = self.trainer.datamodule.train_dataloader()

        # initialise the accumulation of the squared gradients
        num_data = 0
        for param_name, param in self.backbone.named_parameters():
            parameter_importance_t[param_name] = torch.zeros_like(param)

        for x, y in last_task_train_dataloader:

            # move data to device manually
            x = x.to(self.device)
            y = y.to(self.device)

            batch_size = len(y)
            num_data += batch_size

            # compute the gradients within a batch
            self.backbone.zero_grad()  # reset gradients
            logits, _ = self.forward(x, stage="train", task_id=self.task_id)
            loss_cls = self.criterion(logits, y)
            loss_cls.backward()  # compute gradients

            # collect and accumulate the squared gradients into parameter importance
            for param_name, param in self.backbone.named_parameters():
                parameter_importance_t[param_name] += batch_size * param.grad**2

        num_params = sum(p.numel() for p in self.backbone.parameters())

        for param_name, param in self.backbone.named_parameters():
            parameter_importance_t[param_name] /= (
                num_data * num_params
            )  # average over data and parameters

        self.parameter_importance[self.task_id] = parameter_importance_t

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are the metric names and values (`Tensor`) are the metrics. It must include the key `loss`, which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
        """
        x, y = batch

        # classification loss
        logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)

        # regularisation loss. See equation (3) in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114).
        loss_reg = 0.0
        for previous_task_id in range(1, self.task_id):
            # sum over all previous models, because the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114) says: "When moving to a third task, task C, EWC will try to keep the network parameters close to the learned parameters of both tasks A and B. This can be enforced either with two separate penalties or as one by noting that the sum of two quadratic penalties is itself a quadratic penalty."
            loss_reg += self.parameter_change_reg(
                target_model=self.backbone,
                ref_model=self.previous_task_backbones[previous_task_id],
                weights=self.parameter_importance[previous_task_id],
            )

        if self.task_id != 1:
            loss_reg /= (
                self.task_id
            )  # normalise by the task count so the regularisation loss doesn't grow linearly with the number of tasks

        # total loss
        loss = loss_cls + loss_reg

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "loss_reg": loss_reg,
            "acc": acc,
            "hidden_features": hidden_features,
        }

    def on_train_end(self) -> None:
        r"""Calculate the parameter importance and store the backbone model after the training of a task.

        The calculated importance and model are stored in `self.parameter_importance[self.task_id]` and `self.previous_task_backbones[self.task_id]` respectively, for constructing the regularisation loss in future tasks.
        """
        self.calculate_parameter_importance()

        previous_backbone = deepcopy(self.backbone)
        previous_backbone.eval()  # set the stored model to evaluation mode to prevent it from updating
        self.previous_task_backbones[self.task_id] = previous_backbone
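
To make the importance formula above concrete, here is a self-contained sketch (independent of clarena, using a toy model and fake data) of the same computation `calculate_parameter_importance()` performs: accumulating batch-size-weighted squared gradients of the loss and averaging over the data.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# toy stand-ins for the backbone-plus-head and the previous task's training set
model = nn.Linear(4, 3)
criterion = nn.CrossEntropyLoss()
loader = DataLoader(
    TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,))), batch_size=8
)

# accumulate batch-size-weighted squared gradients, as the method above does
importance = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
num_data = 0
model.eval()
for x, y in loader:
    model.zero_grad()
    criterion(model(x), y).backward()
    num_data += len(y)
    for name, p in model.named_parameters():
        importance[name] += len(y) * p.grad**2

# average over the data; the module above additionally divides by the total
# parameter count as an extra normalisation
for name in importance:
    importance[name] /= num_data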
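
For reference, the step from the Fisher information to the squared-gradient formula in `calculate_parameter_importance()` is that the diagonal Fisher entry is the expected squared score, which is approximated by the empirical average over the previous task's training set (strictly speaking, plugging in the observed labels yields the so-called empirical Fisher):

$$F_i = \mathbb{E}_{(\mathbf{x}, y)}\left[\left(\frac{\partial \log p(y \mid \mathbf{x}, \theta)}{\partial \theta_i}\right)^2\right] \approx \frac{1}{N_{t-1}} \sum_{(\mathbf{x}, y)\in \mathcal{D}^{(t-1)}_{\text{train}}} \left[\frac{\partial l\left(f^{(t-1)}(\mathbf{x}, \theta), y\right)}{\partial \theta_i}\right]^2$$

since the classification loss $l$ (cross-entropy) is the negative log-likelihood $-\log p(y \mid \mathbf{x}, \theta)$.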

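The `ParameterChangeReg` regulariser used in `training_step()` is defined in `clarena.cl_algorithms.regularisers` and is not shown on this page. As a rough mental model only (an assumption based on equation (3) of the EWC paper, not the actual implementation), an importance-weighted parameter-change penalty of this kind can be sketched as:

import torch
from torch import nn

def parameter_change_penalty(
    target_model: nn.Module,
    ref_model: nn.Module,
    weights: dict[str, torch.Tensor],
    factor: float = 1.0,
    p_norm: float = 2.0,
) -> torch.Tensor:
    # HYPOTHETICAL sketch of an importance-weighted parameter-change penalty;
    # see clarena.cl_algorithms.regularisers.ParameterChangeReg for the real one
    ref_params = dict(ref_model.named_parameters())
    return factor * sum(
        (weights[name] * (param - ref_params[name].detach()).abs() ** p_norm).sum()
        for name, param in target_model.named_parameters()
    )

For `p_norm=2` this reduces to the quadratic EWC penalty $\sum_i F_i (\theta_i - \theta^{*}_{A,i})^2$ of equation (3) in the paper (up to the $\lambda/2$ factor), with the stored importance playing the role of $F_i$.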