clarena.cl_algorithms.ewc

The submodule in cl_algorithms for EWC (Elastic Weight Consolidation) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114).
  3"""
  4
  5__all__ = ["EWC", "AmnesiacEWC"]
  6
  7import logging
  8from copy import deepcopy
  9from typing import Any
 10
 11import torch
 12from torch import Tensor, nn
 13
 14from clarena.backbones import CLBackbone
 15from clarena.cl_algorithms import AmnesiacCLAlgorithm, Finetuning
 16from clarena.cl_algorithms.regularizers import ParameterChangeReg
 17from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL
 18
 19# always get logger for built-in logging in each module
 20pylogger = logging.getLogger(__name__)
 21
 22
 23class EWC(Finetuning):
 24    r"""[EWC (Elastic Weight Consolidation)](https://www.pnas.org/doi/10.1073/pnas.1611835114) algorithm.
 25
 26    A regularization-based approach that calculates the fisher information as parameter importance for the previous tasks and penalizes the current task loss with the importance of the parameters.
 27
 28    We implement EWC as a subclass of Finetuning algorithm, as EWC has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class.
 29    """
 30
 31    def __init__(
 32        self,
 33        backbone: CLBackbone,
 34        heads: HeadsTIL | HeadsCIL | HeadDIL,
 35        parameter_change_reg_factor: float,
 36        when_calculate_fisher_information: str,
 37        non_algorithmic_hparams: dict[str, Any] = {},
 38        **kwargs,
 39    ) -> None:
 40        r"""Initialize the EWC algorithm with the network.
 41
 42        **Args:**
 43        - **backbone** (`CLBackbone`): backbone network.
 44        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
 45        - **parameter_change_reg_factor** (`float`): the parameter change regularization factor. It controls the strength of preventing forgetting.
 46        - **when_calculate_fisher_information** (`str`): when to calculate the fisher information. It should be one of the following:
 47            1. 'train_end': calculate the fisher information at the end of training of the task.
 48            2. 'train': accumulate the fisher information in the training step of the task.
 49        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 50        - **kwargs**: Reserved for multiple inheritance.
 51
 52        """
 53        super().__init__(
 54            backbone=backbone,
 55            heads=heads,
 56            non_algorithmic_hparams=non_algorithmic_hparams,
 57            **kwargs,
 58        )
 59
 60        # save additional algorithmic hyperparameters
 61        self.save_hyperparameters(
 62            "parameter_change_reg_factor",
 63            "when_calculate_fisher_information",
 64        )
 65
 66        self.parameter_importance: dict[int, dict[str, Tensor]] = {}
 67        r"""The parameter importance of each previous task. Keys are task IDs and values are the corresponding importance. Each importance entity is a dict where keys are parameter names (named by `named_parameters()` of the `nn.Module`) and values are the importance tensor for the layer. It has the same shape as the parameters of the layer.
 68        """
 69
 70        self.previous_task_backbones: dict[int, nn.Module] = {}
 71        r"""The backbone models of the previous tasks. Keys are task IDs and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained.
 72
 73        Some would argue that since we could store the model of the previous tasks, why don't we test the task directly with the stored model, instead of doing the less easier EWC thing? The thing is, EWC only uses the model of the previous tasks to train current and future tasks, which aggregate them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this make the future tasks not able to use EWC anymore, which is a disadvantage for EWC.
 74        """
 75        self.parameter_importance_heads: dict[int, dict[str, Tensor]] = {}
 76        r"""The head parameter importance of each previous task (DIL only)."""
 77        self.previous_task_heads: dict[int, nn.Module] = {}
 78        r"""The head models of the previous tasks (DIL only)."""
 79
 80        self.parameter_change_reg_factor = parameter_change_reg_factor
 81        r"""The parameter change regularization factor."""
 82        self.parameter_change_reg = ParameterChangeReg(
 83            factor=parameter_change_reg_factor,
 84        )
 85        r"""Initialize and store the parameter change regularizer."""
 86
 87        self.when_calculate_fisher_information: str = when_calculate_fisher_information
 88        r"""When to calculate the fisher information."""
 89        self.num_data: int
 90        r"""The number of data used to calculate the fisher information. It is used to average the fisher information over the data."""
 91
 92        # set manual optimization because we need to access gradients to calculate the fisher information in the training step
 93        self.automatic_optimization = False
 94
 95        EWC.sanity_check(self)
 96
 97    def sanity_check(self) -> None:
 98        r"""Sanity check."""
 99        if self.parameter_change_reg_factor <= 0:
100            raise ValueError(
101                f"The parameter change regularization factor should be positive, but got {self.parameter_change_reg_factor}."
102            )
103
104    def on_train_start(self) -> None:
105        r"""Initialize the parameter importance and num of data counter."""
106
107        self.parameter_importance[self.task_id] = {}
108        for param_name, param in self.backbone.named_parameters():
109            self.parameter_importance[self.task_id][param_name] = 0 * param.data
110        if isinstance(self.heads, HeadDIL):
111            self.parameter_importance_heads[self.task_id] = {}
112            for param_name, param in self.heads.named_parameters():
113                self.parameter_importance_heads[self.task_id][param_name] = (
114                    0 * param.data
115                )
116        self.num_data = 0
117
118    def training_step(self, batch: Any) -> dict[str, Tensor]:
119        r"""Training step for current task `self.task_id`."""
120        x, y = batch
121
122        opt = self.optimizers()
123        opt.zero_grad()
124
125        # classification loss
126        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
127        loss_cls = self.criterion(logits, y)
128
129        batch_size = len(y)
130        self.num_data += batch_size
131
132        # accumulate fisher information during training step if specified
133        if self.when_calculate_fisher_information == "train":
134            # Use autograd.grad to get explicit gradients for Fisher accumulation without
135            # relying on global .backward() state in manual optimization.
136            backbone_params: list[tuple[str, Tensor]] = []
137            for param_name, param in self.backbone.named_parameters():
138                if not param.requires_grad:
139                    continue
140                backbone_params.append((param_name, param))
141            if isinstance(self.heads, HeadDIL):
142                head_params: list[tuple[str, Tensor]] = []
143                for param_name, param in self.heads.named_parameters():
144                    if not param.requires_grad:
145                        continue
146                    head_params.append((param_name, param))
147            else:
148                head_params = []
149
150            if backbone_params:
151                grads = torch.autograd.grad(
152                    loss_cls,
153                    [param for _, param in backbone_params],
154                    retain_graph=True,
155                    allow_unused=True,
156                )
157                for (param_name, _), grad in zip(backbone_params, grads):
158                    if grad is None:
159                        continue
160                    self.parameter_importance[self.task_id][param_name] += (
161                        batch_size * grad.detach() ** 2
162                    )
163
164            if head_params:
165                grads = torch.autograd.grad(
166                    loss_cls,
167                    [param for _, param in head_params],
168                    retain_graph=True,
169                    allow_unused=True,
170                )
171                for (param_name, _), grad in zip(head_params, grads):
172                    if grad is None:
173                        continue
174                    self.parameter_importance_heads[self.task_id][param_name] += (
175                        batch_size * grad.detach() ** 2
176                    )
177
178        # regularization loss. See equation (3) in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114)
179        ewc_reg = 0.0
180        for previous_task_id, previous_backbone in self.previous_task_backbones.items():
181            ewc_reg += 0.5 * self.parameter_change_reg(
182                target_model=self.backbone,
183                ref_model=previous_backbone,
184                weights=self.parameter_importance[previous_task_id],
185            )
186            if isinstance(self.heads, HeadDIL):
187                ewc_reg += 0.5 * self.parameter_change_reg(
188                    target_model=self.heads,
189                    ref_model=self.previous_task_heads[previous_task_id],
190                    weights=self.parameter_importance_heads[previous_task_id],
191                )
192
193        # total loss
194        loss = loss_cls + ewc_reg
195
196        self.manual_backward(loss)
197        opt.step()
198
199        # predicted labels
200        preds = logits.argmax(dim=1)
201
202        # accuracy of the batch
203        acc = (preds == y).float().mean()
204
205        return {
206            "preds": preds,
207            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
208            "loss_cls": loss_cls,
209            "ewc_reg": ewc_reg,
210            "acc": acc,
211            "activations": activations,
212        }
213
214    def on_train_end(self) -> None:
215        r"""Calculate the fisher information as parameter importance and store the backbone model after the training of a task."""
216
217        # calculate fisher information at the end of training if specified
218        if self.when_calculate_fisher_information == "train_end":
219            fisher, fisher_heads, fisher_num_data = (
220                self.accumulate_fisher_information_on_train_end()
221            )
222            self.parameter_importance[self.task_id] = fisher
223            if fisher_heads is not None:
224                self.parameter_importance_heads[self.task_id] = fisher_heads
225            num_data = fisher_num_data
226        else:
227            num_data = self.num_data
228
229        # no matter when we calculate the fisher information, we need to average it over the number of data
230        for param_name, param in self.backbone.named_parameters():
231            self.parameter_importance[self.task_id][param_name] /= num_data
232        if isinstance(self.heads, HeadDIL):
233            for param_name, param in self.heads.named_parameters():
234                self.parameter_importance_heads[self.task_id][param_name] /= num_data
235
236        # store the backbone model after training the task
237        previous_backbone = deepcopy(self.backbone)
238        previous_backbone.eval()
239        self.previous_task_backbones[self.task_id] = previous_backbone
240        if isinstance(self.heads, HeadDIL):
241            previous_heads = deepcopy(self.heads)
242            previous_heads.eval()
243            self.previous_task_heads[self.task_id] = previous_heads
244
245    def accumulate_fisher_information_on_train_end(
246        self,
247    ) -> tuple[dict[str, Tensor], dict[str, Tensor] | None, int]:
248        r"""Accumulate the fisher information as the parameter importance for the learned task `self.task_id` at the end of its training."""
249        fisher_information_t = {}
250        fisher_information_heads: dict[str, Tensor] | None = None
251        num_data = 0
252
253        self.eval()
254        last_task_train_dataloaders = self.trainer.datamodule.train_dataloader()
255
256        for param_name, param in self.backbone.named_parameters():
257            fisher_information_t[param_name] = torch.zeros_like(param)
258        if isinstance(self.heads, HeadDIL):
259            fisher_information_heads = {}
260            for param_name, param in self.heads.named_parameters():
261                fisher_information_heads[param_name] = torch.zeros_like(param)
262
263        for x, y in last_task_train_dataloaders:
264            x = x.to(self.device)
265            y = y.to(self.device)
266            batch_size = len(y)
267            num_data += batch_size
268
269            self.backbone.zero_grad()
270            if isinstance(self.heads, HeadDIL):
271                self.heads.zero_grad()
272            logits, _ = self.forward(x, stage="train", task_id=self.task_id)
273            loss_cls = self.criterion(logits, y)
274            loss_cls.backward()
275
276            for param_name, param in self.backbone.named_parameters():
277                fisher_information_t[param_name] += batch_size * param.grad**2
278            if fisher_information_heads is not None:
279                for param_name, param in self.heads.named_parameters():
280                    if param.grad is None:
281                        continue
282                    fisher_information_heads[param_name] += batch_size * param.grad**2
283
284        return fisher_information_t, fisher_information_heads, num_data
285
286
class AmnesiacEWC(AmnesiacCLAlgorithm, EWC):
    r"""Amnesiac EWC algorithm.

    Combines the unlearning support of `AmnesiacCLAlgorithm` with the `EWC` algorithm through cooperative multiple inheritance.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] = {},
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""Initialize the Amnesiac EWC algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
        - **kwargs**: Reserved for multiple inheritance. EWC-specific arguments (e.g. `parameter_change_reg_factor`, `when_calculate_fisher_information`) are forwarded through here to `EWC.__init__()`.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            disable_unlearning=disable_unlearning,
            **kwargs,
        )

    def on_train_start(self) -> None:
        r"""Run the train-start hooks of both parent classes.

        Both parents are called explicitly so that each hook runs regardless of whether the parent implementations chain `super()`.
        """
        EWC.on_train_start(self)
        AmnesiacCLAlgorithm.on_train_start(self)

    def on_train_end(self) -> None:
        r"""Run the train-end hooks of both parent classes.

        `EWC.on_train_end()` finalizes the Fisher information and stores the trained model snapshot; `AmnesiacCLAlgorithm.on_train_end()` is invoked explicitly as well.
        """
        EWC.on_train_end(self)
        AmnesiacCLAlgorithm.on_train_end(self)
 24class EWC(Finetuning):
 25    r"""[EWC (Elastic Weight Consolidation)](https://www.pnas.org/doi/10.1073/pnas.1611835114) algorithm.
 26
 27    A regularization-based approach that calculates the fisher information as parameter importance for the previous tasks and penalizes the current task loss with the importance of the parameters.
 28
 29    We implement EWC as a subclass of Finetuning algorithm, as EWC has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class.
 30    """
 31
 32    def __init__(
 33        self,
 34        backbone: CLBackbone,
 35        heads: HeadsTIL | HeadsCIL | HeadDIL,
 36        parameter_change_reg_factor: float,
 37        when_calculate_fisher_information: str,
 38        non_algorithmic_hparams: dict[str, Any] = {},
 39        **kwargs,
 40    ) -> None:
 41        r"""Initialize the EWC algorithm with the network.
 42
 43        **Args:**
 44        - **backbone** (`CLBackbone`): backbone network.
 45        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
 46        - **parameter_change_reg_factor** (`float`): the parameter change regularization factor. It controls the strength of preventing forgetting.
 47        - **when_calculate_fisher_information** (`str`): when to calculate the fisher information. It should be one of the following:
 48            1. 'train_end': calculate the fisher information at the end of training of the task.
 49            2. 'train': accumulate the fisher information in the training step of the task.
 50        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 51        - **kwargs**: Reserved for multiple inheritance.
 52
 53        """
 54        super().__init__(
 55            backbone=backbone,
 56            heads=heads,
 57            non_algorithmic_hparams=non_algorithmic_hparams,
 58            **kwargs,
 59        )
 60
 61        # save additional algorithmic hyperparameters
 62        self.save_hyperparameters(
 63            "parameter_change_reg_factor",
 64            "when_calculate_fisher_information",
 65        )
 66
 67        self.parameter_importance: dict[int, dict[str, Tensor]] = {}
 68        r"""The parameter importance of each previous task. Keys are task IDs and values are the corresponding importance. Each importance entity is a dict where keys are parameter names (named by `named_parameters()` of the `nn.Module`) and values are the importance tensor for the layer. It has the same shape as the parameters of the layer.
 69        """
 70
 71        self.previous_task_backbones: dict[int, nn.Module] = {}
 72        r"""The backbone models of the previous tasks. Keys are task IDs and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained.
 73
 74        Some would argue that since we could store the model of the previous tasks, why don't we test the task directly with the stored model, instead of doing the less easier EWC thing? The thing is, EWC only uses the model of the previous tasks to train current and future tasks, which aggregate them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this make the future tasks not able to use EWC anymore, which is a disadvantage for EWC.
 75        """
 76        self.parameter_importance_heads: dict[int, dict[str, Tensor]] = {}
 77        r"""The head parameter importance of each previous task (DIL only)."""
 78        self.previous_task_heads: dict[int, nn.Module] = {}
 79        r"""The head models of the previous tasks (DIL only)."""
 80
 81        self.parameter_change_reg_factor = parameter_change_reg_factor
 82        r"""The parameter change regularization factor."""
 83        self.parameter_change_reg = ParameterChangeReg(
 84            factor=parameter_change_reg_factor,
 85        )
 86        r"""Initialize and store the parameter change regularizer."""
 87
 88        self.when_calculate_fisher_information: str = when_calculate_fisher_information
 89        r"""When to calculate the fisher information."""
 90        self.num_data: int
 91        r"""The number of data used to calculate the fisher information. It is used to average the fisher information over the data."""
 92
 93        # set manual optimization because we need to access gradients to calculate the fisher information in the training step
 94        self.automatic_optimization = False
 95
 96        EWC.sanity_check(self)
 97
 98    def sanity_check(self) -> None:
 99        r"""Sanity check."""
100        if self.parameter_change_reg_factor <= 0:
101            raise ValueError(
102                f"The parameter change regularization factor should be positive, but got {self.parameter_change_reg_factor}."
103            )
104
105    def on_train_start(self) -> None:
106        r"""Initialize the parameter importance and num of data counter."""
107
108        self.parameter_importance[self.task_id] = {}
109        for param_name, param in self.backbone.named_parameters():
110            self.parameter_importance[self.task_id][param_name] = 0 * param.data
111        if isinstance(self.heads, HeadDIL):
112            self.parameter_importance_heads[self.task_id] = {}
113            for param_name, param in self.heads.named_parameters():
114                self.parameter_importance_heads[self.task_id][param_name] = (
115                    0 * param.data
116                )
117        self.num_data = 0
118
119    def training_step(self, batch: Any) -> dict[str, Tensor]:
120        r"""Training step for current task `self.task_id`."""
121        x, y = batch
122
123        opt = self.optimizers()
124        opt.zero_grad()
125
126        # classification loss
127        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
128        loss_cls = self.criterion(logits, y)
129
130        batch_size = len(y)
131        self.num_data += batch_size
132
133        # accumulate fisher information during training step if specified
134        if self.when_calculate_fisher_information == "train":
135            # Use autograd.grad to get explicit gradients for Fisher accumulation without
136            # relying on global .backward() state in manual optimization.
137            backbone_params: list[tuple[str, Tensor]] = []
138            for param_name, param in self.backbone.named_parameters():
139                if not param.requires_grad:
140                    continue
141                backbone_params.append((param_name, param))
142            if isinstance(self.heads, HeadDIL):
143                head_params: list[tuple[str, Tensor]] = []
144                for param_name, param in self.heads.named_parameters():
145                    if not param.requires_grad:
146                        continue
147                    head_params.append((param_name, param))
148            else:
149                head_params = []
150
151            if backbone_params:
152                grads = torch.autograd.grad(
153                    loss_cls,
154                    [param for _, param in backbone_params],
155                    retain_graph=True,
156                    allow_unused=True,
157                )
158                for (param_name, _), grad in zip(backbone_params, grads):
159                    if grad is None:
160                        continue
161                    self.parameter_importance[self.task_id][param_name] += (
162                        batch_size * grad.detach() ** 2
163                    )
164
165            if head_params:
166                grads = torch.autograd.grad(
167                    loss_cls,
168                    [param for _, param in head_params],
169                    retain_graph=True,
170                    allow_unused=True,
171                )
172                for (param_name, _), grad in zip(head_params, grads):
173                    if grad is None:
174                        continue
175                    self.parameter_importance_heads[self.task_id][param_name] += (
176                        batch_size * grad.detach() ** 2
177                    )
178
179        # regularization loss. See equation (3) in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114)
180        ewc_reg = 0.0
181        for previous_task_id, previous_backbone in self.previous_task_backbones.items():
182            ewc_reg += 0.5 * self.parameter_change_reg(
183                target_model=self.backbone,
184                ref_model=previous_backbone,
185                weights=self.parameter_importance[previous_task_id],
186            )
187            if isinstance(self.heads, HeadDIL):
188                ewc_reg += 0.5 * self.parameter_change_reg(
189                    target_model=self.heads,
190                    ref_model=self.previous_task_heads[previous_task_id],
191                    weights=self.parameter_importance_heads[previous_task_id],
192                )
193
194        # total loss
195        loss = loss_cls + ewc_reg
196
197        self.manual_backward(loss)
198        opt.step()
199
200        # predicted labels
201        preds = logits.argmax(dim=1)
202
203        # accuracy of the batch
204        acc = (preds == y).float().mean()
205
206        return {
207            "preds": preds,
208            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
209            "loss_cls": loss_cls,
210            "ewc_reg": ewc_reg,
211            "acc": acc,
212            "activations": activations,
213        }
214
215    def on_train_end(self) -> None:
216        r"""Calculate the fisher information as parameter importance and store the backbone model after the training of a task."""
217
218        # calculate fisher information at the end of training if specified
219        if self.when_calculate_fisher_information == "train_end":
220            fisher, fisher_heads, fisher_num_data = (
221                self.accumulate_fisher_information_on_train_end()
222            )
223            self.parameter_importance[self.task_id] = fisher
224            if fisher_heads is not None:
225                self.parameter_importance_heads[self.task_id] = fisher_heads
226            num_data = fisher_num_data
227        else:
228            num_data = self.num_data
229
230        # no matter when we calculate the fisher information, we need to average it over the number of data
231        for param_name, param in self.backbone.named_parameters():
232            self.parameter_importance[self.task_id][param_name] /= num_data
233        if isinstance(self.heads, HeadDIL):
234            for param_name, param in self.heads.named_parameters():
235                self.parameter_importance_heads[self.task_id][param_name] /= num_data
236
237        # store the backbone model after training the task
238        previous_backbone = deepcopy(self.backbone)
239        previous_backbone.eval()
240        self.previous_task_backbones[self.task_id] = previous_backbone
241        if isinstance(self.heads, HeadDIL):
242            previous_heads = deepcopy(self.heads)
243            previous_heads.eval()
244            self.previous_task_heads[self.task_id] = previous_heads
245
246    def accumulate_fisher_information_on_train_end(
247        self,
248    ) -> tuple[dict[str, Tensor], dict[str, Tensor] | None, int]:
249        r"""Accumulate the fisher information as the parameter importance for the learned task `self.task_id` at the end of its training."""
250        fisher_information_t = {}
251        fisher_information_heads: dict[str, Tensor] | None = None
252        num_data = 0
253
254        self.eval()
255        last_task_train_dataloaders = self.trainer.datamodule.train_dataloader()
256
257        for param_name, param in self.backbone.named_parameters():
258            fisher_information_t[param_name] = torch.zeros_like(param)
259        if isinstance(self.heads, HeadDIL):
260            fisher_information_heads = {}
261            for param_name, param in self.heads.named_parameters():
262                fisher_information_heads[param_name] = torch.zeros_like(param)
263
264        for x, y in last_task_train_dataloaders:
265            x = x.to(self.device)
266            y = y.to(self.device)
267            batch_size = len(y)
268            num_data += batch_size
269
270            self.backbone.zero_grad()
271            if isinstance(self.heads, HeadDIL):
272                self.heads.zero_grad()
273            logits, _ = self.forward(x, stage="train", task_id=self.task_id)
274            loss_cls = self.criterion(logits, y)
275            loss_cls.backward()
276
277            for param_name, param in self.backbone.named_parameters():
278                fisher_information_t[param_name] += batch_size * param.grad**2
279            if fisher_information_heads is not None:
280                for param_name, param in self.heads.named_parameters():
281                    if param.grad is None:
282                        continue
283                    fisher_information_heads[param_name] += batch_size * param.grad**2
284
285        return fisher_information_t, fisher_information_heads, num_data

EWC (Elastic Weight Consolidation) algorithm.

A regularization-based approach that calculates the fisher information as parameter importance for the previous tasks and penalizes the current task loss with the importance of the parameters.

We implement EWC as a subclass of the Finetuning algorithm, as EWC has the same forward(), validation_step() and test_step() methods as the Finetuning class.

EWC( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, parameter_change_reg_factor: float, when_calculate_fisher_information: str, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
def __init__(
    self,
    backbone: CLBackbone,
    heads: HeadsTIL | HeadsCIL | HeadDIL,
    parameter_change_reg_factor: float,
    when_calculate_fisher_information: str,
    non_algorithmic_hparams: dict[str, Any] | None = None,
    **kwargs,
) -> None:
    r"""Initialize the EWC algorithm with the network.

    **Args:**
    - **backbone** (`CLBackbone`): backbone network.
    - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
    - **parameter_change_reg_factor** (`float`): the parameter change regularization factor. It controls the strength of preventing forgetting.
    - **when_calculate_fisher_information** (`str`): when to calculate the fisher information. It should be one of the following:
        1. 'train_end': calculate the fisher information at the end of training of the task.
        2. 'train': accumulate the fisher information in the training step of the task.
    - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. Defaults to `None`, which is treated as an empty dict.
    - **kwargs**: Reserved for multiple inheritance.
    """
    # Normalize here instead of using a mutable `{}` default argument, which
    # would be a single dict object shared across every call of this constructor.
    if non_algorithmic_hparams is None:
        non_algorithmic_hparams = {}

    super().__init__(
        backbone=backbone,
        heads=heads,
        non_algorithmic_hparams=non_algorithmic_hparams,
        **kwargs,
    )

    # save additional algorithmic hyperparameters
    self.save_hyperparameters(
        "parameter_change_reg_factor",
        "when_calculate_fisher_information",
    )

    self.parameter_importance: dict[int, dict[str, Tensor]] = {}
    r"""The parameter importance of each previous task. Keys are task IDs and values are the corresponding importance. Each importance entity is a dict where keys are parameter names (named by `named_parameters()` of the `nn.Module`) and values are the importance tensor for the layer. It has the same shape as the parameters of the layer.
    """

    self.previous_task_backbones: dict[int, nn.Module] = {}
    r"""The backbone models of the previous tasks. Keys are task IDs and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained.

    Some would argue that since we could store the models of the previous tasks, why don't we test each task directly with its stored model instead of going through the more involved EWC procedure? The thing is, EWC only uses the models of the previous tasks to train current and future tasks, which aggregates them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this makes future tasks unable to use EWC anymore, which is a disadvantage for EWC.
    """
    self.parameter_importance_heads: dict[int, dict[str, Tensor]] = {}
    r"""The head parameter importance of each previous task (DIL only)."""
    self.previous_task_heads: dict[int, nn.Module] = {}
    r"""The head models of the previous tasks (DIL only)."""

    self.parameter_change_reg_factor = parameter_change_reg_factor
    r"""The parameter change regularization factor."""
    self.parameter_change_reg = ParameterChangeReg(
        factor=parameter_change_reg_factor,
    )
    r"""Initialize and store the parameter change regularizer."""

    self.when_calculate_fisher_information: str = when_calculate_fisher_information
    r"""When to calculate the fisher information."""
    self.num_data: int
    r"""The number of data used to calculate the fisher information. It is used to average the fisher information over the data."""

    # set manual optimization because we need to access gradients to calculate the fisher information in the training step
    self.automatic_optimization = False

    # called unbound on purpose so a subclass override of `sanity_check`
    # cannot break this constructor
    EWC.sanity_check(self)

Initialize the EWC algorithm with the network.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL | HeadDIL): output heads.
  • parameter_change_reg_factor (float): the parameter change regularization factor. It controls the strength of preventing forgetting.
  • when_calculate_fisher_information (str): when to calculate the fisher information. It should be one of the following:
    1. 'train_end': calculate the fisher information at the end of training of the task.
    2. 'train': accumulate the fisher information in the training step of the task.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
parameter_importance: dict[int, dict[str, torch.Tensor]]

The parameter importance of each previous task. Keys are task IDs and values are the corresponding importance. Each importance entity is a dict where keys are parameter names (named by named_parameters() of the nn.Module) and values are the importance tensor for the layer. It has the same shape as the parameters of the layer.

previous_task_backbones: dict[int, torch.nn.modules.module.Module]

The backbone models of the previous tasks. Keys are task IDs and values are the corresponding models. Each model is a nn.Module backbone after the corresponding previous task was trained.

Some would argue that since we could store the models of the previous tasks, why don't we test each task directly with its stored model instead of going through the more involved EWC procedure? The thing is, EWC only uses the models of the previous tasks to train current and future tasks, which aggregates them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this makes future tasks unable to use EWC anymore, which is a disadvantage for EWC.

parameter_importance_heads: dict[int, dict[str, torch.Tensor]]

The head parameter importance of each previous task (DIL only).

previous_task_heads: dict[int, torch.nn.modules.module.Module]

The head models of the previous tasks (DIL only).

parameter_change_reg_factor

The parameter change regularization factor.

parameter_change_reg

Initialize and store the parameter change regularizer.

when_calculate_fisher_information: str

When to calculate the fisher information.

num_data: int

The number of data used to calculate the fisher information. It is used to average the fisher information over the data.

automatic_optimization: bool
290    @property
291    def automatic_optimization(self) -> bool:
292        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
        # NOTE(review): this is the inherited Lightning ``LightningModule`` property,
        # shown here because `EWC.__init__` sets ``self.automatic_optimization = False``
        # to manage the optimizer manually in `training_step`.
293        return self._automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

def sanity_check(self) -> None:
 98    def sanity_check(self) -> None:
 99        r"""Sanity check."""
100        if self.parameter_change_reg_factor <= 0:
101            raise ValueError(
102                f"The parameter change regularization factor should be positive, but got {self.parameter_change_reg_factor}."
103            )

Sanity check.

def on_train_start(self) -> None:
def on_train_start(self) -> None:
    r"""Initialize the parameter importance and num of data counter."""
    # One zero tensor per backbone parameter, shaped like the parameter itself.
    # `0 * param.data` (rather than `zeros_like`) is kept deliberately so any
    # NaN/inf in parameters propagates exactly as before.
    self.parameter_importance[self.task_id] = {
        name: 0 * param.data for name, param in self.backbone.named_parameters()
    }

    # DIL heads are shared across tasks, so their parameters get importance too.
    if isinstance(self.heads, HeadDIL):
        self.parameter_importance_heads[self.task_id] = {
            name: 0 * param.data for name, param in self.heads.named_parameters()
        }

    # Counter used by `on_train_end` to average the accumulated fisher information.
    self.num_data = 0

Initialize the parameter importance and num of data counter.

def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
119    def training_step(self, batch: Any) -> dict[str, Tensor]:
120        r"""Training step for current task `self.task_id`."""
121        x, y = batch
122
        # Manual optimization (self.automatic_optimization = False): we zero,
        # backward and step the optimizer ourselves so extra gradients can be
        # taken from the same graph for the fisher accumulation below.
123        opt = self.optimizers()
124        opt.zero_grad()
125
126        # classification loss
127        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
128        loss_cls = self.criterion(logits, y)
129
130        batch_size = len(y)
131        self.num_data += batch_size
132
133        # accumulate fisher information during training step if specified
134        if self.when_calculate_fisher_information == "train":
135            # Use autograd.grad to get explicit gradients for Fisher accumulation without
136            # relying on global .backward() state in manual optimization.
137            backbone_params: list[tuple[str, Tensor]] = []
138            for param_name, param in self.backbone.named_parameters():
139                if not param.requires_grad:
140                    continue
141                backbone_params.append((param_name, param))
142            if isinstance(self.heads, HeadDIL):
143                head_params: list[tuple[str, Tensor]] = []
144                for param_name, param in self.heads.named_parameters():
145                    if not param.requires_grad:
146                        continue
147                    head_params.append((param_name, param))
148            else:
149                head_params = []
150
            # Fisher accumulated as batch_size * grad**2 per parameter; it is
            # averaged over `self.num_data` later in `on_train_end`.
            # retain_graph=True keeps the graph alive for the head gradients below
            # and for the final `manual_backward`; allow_unused=True yields None
            # for parameters not on this forward path, which are skipped.
151            if backbone_params:
152                grads = torch.autograd.grad(
153                    loss_cls,
154                    [param for _, param in backbone_params],
155                    retain_graph=True,
156                    allow_unused=True,
157                )
158                for (param_name, _), grad in zip(backbone_params, grads):
159                    if grad is None:
160                        continue
161                    self.parameter_importance[self.task_id][param_name] += (
162                        batch_size * grad.detach() ** 2
163                    )
164
165            if head_params:
166                grads = torch.autograd.grad(
167                    loss_cls,
168                    [param for _, param in head_params],
169                    retain_graph=True,
170                    allow_unused=True,
171                )
172                for (param_name, _), grad in zip(head_params, grads):
173                    if grad is None:
174                        continue
175                    self.parameter_importance_heads[self.task_id][param_name] += (
176                        batch_size * grad.detach() ** 2
177                    )
178
179        # regularization loss. See equation (3) in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114)
        # For the first task there are no previous backbones, so ewc_reg stays
        # the plain float 0.0 and the total loss reduces to loss_cls.
180        ewc_reg = 0.0
181        for previous_task_id, previous_backbone in self.previous_task_backbones.items():
182            ewc_reg += 0.5 * self.parameter_change_reg(
183                target_model=self.backbone,
184                ref_model=previous_backbone,
185                weights=self.parameter_importance[previous_task_id],
186            )
187            if isinstance(self.heads, HeadDIL):
188                ewc_reg += 0.5 * self.parameter_change_reg(
189                    target_model=self.heads,
190                    ref_model=self.previous_task_heads[previous_task_id],
191                    weights=self.parameter_importance_heads[previous_task_id],
192                )
193
194        # total loss
195        loss = loss_cls + ewc_reg
196
197        self.manual_backward(loss)
198        opt.step()
199
200        # predicted labels
201        preds = logits.argmax(dim=1)
202
203        # accuracy of the batch
204        acc = (preds == y).float().mean()
205
206        return {
207            "preds": preds,
208            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
209            "loss_cls": loss_cls,
210            "ewc_reg": ewc_reg,
211            "acc": acc,
212            "activations": activations,
213        }

Training step for current task self.task_id.

def on_train_end(self) -> None:
def on_train_end(self) -> None:
    r"""Calculate the fisher information as parameter importance and store the backbone model after the training of a task."""
    # In 'train_end' mode the fisher information is computed here in one pass
    # over the training data; in 'train' mode it was already accumulated during
    # the training steps, so only the counter is needed.
    if self.when_calculate_fisher_information == "train_end":
        fisher, fisher_heads, num_data = (
            self.accumulate_fisher_information_on_train_end()
        )
        self.parameter_importance[self.task_id] = fisher
        if fisher_heads is not None:
            self.parameter_importance_heads[self.task_id] = fisher_heads
    else:
        num_data = self.num_data

    # Either way, average the accumulated fisher information over the data count.
    task_importance = self.parameter_importance[self.task_id]
    for name, _ in self.backbone.named_parameters():
        task_importance[name] /= num_data
    if isinstance(self.heads, HeadDIL):
        head_importance = self.parameter_importance_heads[self.task_id]
        for name, _ in self.heads.named_parameters():
            head_importance[name] /= num_data

    # Snapshot the trained backbone (frozen in eval mode) as the reference
    # model for regularizing future tasks.
    backbone_snapshot = deepcopy(self.backbone)
    backbone_snapshot.eval()
    self.previous_task_backbones[self.task_id] = backbone_snapshot
    if isinstance(self.heads, HeadDIL):
        heads_snapshot = deepcopy(self.heads)
        heads_snapshot.eval()
        self.previous_task_heads[self.task_id] = heads_snapshot

Calculate the fisher information as parameter importance and store the backbone model after the training of a task.

def accumulate_fisher_information_on_train_end( self) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor] | None, int]:
def accumulate_fisher_information_on_train_end(
    self,
) -> tuple[dict[str, Tensor], dict[str, Tensor] | None, int]:
    r"""Accumulate the fisher information as the parameter importance for the learned task `self.task_id` at the end of its training.

    **Returns:**
    - the unnormalized fisher information of the backbone parameters (sums of `batch_size * grad**2`; the caller averages over the number of data).
    - the unnormalized fisher information of the head parameters if the heads are `HeadDIL`, otherwise `None`.
    - the number of data samples iterated over.
    """
    fisher_information_t: dict[str, Tensor] = {}
    fisher_information_heads: dict[str, Tensor] | None = None
    num_data = 0

    # NOTE(review): this switches the whole module to eval mode and never
    # restores training mode; fine when called from `on_train_end`, but
    # confirm before reusing this method elsewhere.
    self.eval()
    last_task_train_dataloaders = self.trainer.datamodule.train_dataloader()

    # zero-initialize one accumulator tensor per parameter
    for param_name, param in self.backbone.named_parameters():
        fisher_information_t[param_name] = torch.zeros_like(param)
    if isinstance(self.heads, HeadDIL):
        fisher_information_heads = {}
        for param_name, param in self.heads.named_parameters():
            fisher_information_heads[param_name] = torch.zeros_like(param)

    for x, y in last_task_train_dataloaders:
        x = x.to(self.device)
        y = y.to(self.device)
        batch_size = len(y)
        num_data += batch_size

        self.backbone.zero_grad()
        if isinstance(self.heads, HeadDIL):
            self.heads.zero_grad()
        logits, _ = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        loss_cls.backward()

        for param_name, param in self.backbone.named_parameters():
            if param.grad is None:
                # frozen or unused parameters receive no gradient; skip them
                # instead of crashing on `None ** 2` (consistent with the
                # heads loop below, which already guards against this)
                continue
            fisher_information_t[param_name] += batch_size * param.grad**2
        if fisher_information_heads is not None:
            for param_name, param in self.heads.named_parameters():
                if param.grad is None:
                    continue
                fisher_information_heads[param_name] += batch_size * param.grad**2

    return fisher_information_t, fisher_information_heads, num_data

Accumulate the fisher information as the parameter importance for the learned task self.task_id at the end of its training.

class AmnesiacEWC(clarena.cl_algorithms.base.AmnesiacCLAlgorithm, EWC):
class AmnesiacEWC(AmnesiacCLAlgorithm, EWC):
    r"""Amnesiac EWC algorithm.

    EWC combined with the amnesiac continual learning mechanism from
    `AmnesiacCLAlgorithm`, which supports unlearning unless `disable_unlearning`
    is set.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""Initialize the Amnesiac EWC algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. Defaults to `None`, which is treated as an empty dict.
        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
        - **kwargs**: Reserved for multiple inheritance (e.g. EWC's `parameter_change_reg_factor` and `when_calculate_fisher_information`).
        """
        # normalize here rather than using a mutable `{}` default argument,
        # which would be shared across all constructor calls
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            disable_unlearning=disable_unlearning,
            **kwargs,
        )

    def on_train_start(self) -> None:
        """Initialize EWC importance tracking, then run the amnesiac start-of-training hook."""
        # both bases are invoked explicitly by name; a single `super()` call
        # would only dispatch once along the MRO
        EWC.on_train_start(self)
        AmnesiacCLAlgorithm.on_train_start(self)

    def on_train_end(self) -> None:
        """Finalize EWC fisher information and model snapshots, then run the amnesiac end-of-training hook."""
        EWC.on_train_end(self)
        AmnesiacCLAlgorithm.on_train_end(self)

Amnesiac EWC algorithm.

AmnesiacEWC( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any] = {}, disable_unlearning: bool = False, **kwargs)
291    def __init__(
292        self,
293        backbone: CLBackbone,
294        heads: HeadsTIL | HeadsCIL | HeadDIL,
        # NOTE(review): `{}` as a default argument is a single mutable object shared
        # across calls; it is only forwarded to `super().__init__` here, but
        # `None` plus normalization would be safer — confirm before changing
        # the public signature.
295        non_algorithmic_hparams: dict[str, Any] = {},
296        disable_unlearning: bool = False,
297        **kwargs,
298    ) -> None:
299        r"""Initialize the Amnesiac EWC algorithm with the network.
300
301        **Args:**
302        - **backbone** (`CLBackbone`): backbone network.
303        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
304        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
305        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
306        - **kwargs**: Reserved for multiple inheritance.
307        """
308        super().__init__(
309            backbone=backbone,
310            heads=heads,
311            non_algorithmic_hparams=non_algorithmic_hparams,
312            disable_unlearning=disable_unlearning,
313            **kwargs,
314        )

Initialize the Amnesiac EWC algorithm with the network.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL | HeadDIL): output heads.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • disable_unlearning (bool): whether to disable the unlearning functionality. Default is False.
  • kwargs: Reserved for multiple inheritance.
def on_train_start(self) -> None:
316    def on_train_start(self) -> None:
317        """Initialize EWC importance tracking and record backbone parameters before training current task."""
        # both bases invoked explicitly by name; `super()` would dispatch only
        # once along the (AmnesiacCLAlgorithm, EWC) MRO
318        EWC.on_train_start(self)
319        AmnesiacCLAlgorithm.on_train_start(self)

Record backbone parameters before training current task.

def on_train_end(self) -> None:
321    def on_train_end(self) -> None:
        # the previous docstring was a copy-paste of `on_train_start`'s and
        # described the wrong phase
322        """Finalize EWC fisher information and model snapshots, then run the amnesiac end-of-training hook."""
323        EWC.on_train_end(self)
324        AmnesiacCLAlgorithm.on_train_end(self)

Record backbone parameters before training current task.