clarena.cl_algorithms.lwf

The submodule in cl_algorithms for LwF (Learning without Forgetting) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for [LwF (Learning without Forgetting) algorithm](https://ieeexplore.ieee.org/abstract/document/8107520).
  3"""
  4
  5__all__ = ["LwF", "AmnesiacLwF"]
  6
  7import logging
  8from copy import deepcopy
  9from typing import Any
 10
 11import torch
 12import torch.nn.functional as F
 13from torch import Tensor, nn
 14
 15from clarena.backbones import CLBackbone
 16from clarena.cl_algorithms import AmnesiacCLAlgorithm, Finetuning
 17from clarena.cl_algorithms.regularizers import DistillationReg
 18from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL
 19
 20# always get logger for built-in logging in each module
 21pylogger = logging.getLogger(__name__)
 22
 23
 24class LwF(Finetuning):
 25    r"""[LwF (Learning without Forgetting)](https://ieeexplore.ieee.org/abstract/document/8107520) algorithm.
 26
 27    A regularization-based continual learning approach that constrains the feature output of the model to be similar to that of the previous tasks. From the perspective of knowledge distillation, it distills previous tasks models into the training process for new task in the regularization term. It is a simple yet effective method for continual learning.
 28
 29    We implement LwF as a subclass of Finetuning algorithm, as LwF has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class.
 30    """
 31
 32    def __init__(
 33        self,
 34        backbone: CLBackbone,
 35        heads: HeadsTIL | HeadsCIL | HeadDIL,
 36        distillation_reg_factor: float,
 37        distillation_reg_temperature: float,
 38        non_algorithmic_hparams: dict[str, Any] = {},
 39        **kwargs,
 40    ) -> None:
 41        r"""Initialize the LwF algorithm with the network.
 42
 43        **Args:**
 44        - **backbone** (`CLBackbone`): backbone network.
 45        - **heads** (`HeadsTIL` | `HeadDIL`): output heads.
 46        - **distillation_reg_factor** (`float`): hyperparameter, the distillation regularization factor. It controls the strength of preventing forgetting.
 47        - **distillation_reg_temperature** (`float`): hyperparameter, the temperature in the distillation regularization. It controls the softness of the labels that the student model (here is the current model) learns from the teacher models (here are the previous models), thereby controlling the strength of the distillation. It controls the strength of preventing forgetting.
 48        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 49        - **kwargs**: Reserved for multiple inheritance.
 50
 51        """
 52        super().__init__(
 53            backbone=backbone,
 54            heads=heads,
 55            non_algorithmic_hparams=non_algorithmic_hparams,
 56            **kwargs,
 57        )
 58
 59        self.previous_task_backbones: dict[str, nn.Module] = {}
 60        r"""Store the backbone models of the previous tasks. Keys are task IDs (string type) and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained.
 61        
 62        Some would argue that since we could store the model of the previous tasks, why don't we test the task directly with the stored model, instead of doing the less easier LwF thing? The thing is, LwF only uses the model of the previous tasks to train current and future tasks, which aggregate them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this make the future tasks not able to use LwF anymore, which is a disadvantage for LwF.
 63        """
 64        if isinstance(self.heads, HeadDIL):
 65            self.previous_task_heads: dict[str, nn.Module] = {}
 66            r"""The heads snapshot of the previous task (teacher). This is only used when the heads is `HeadDIL`, because in DIL scenario, all tasks share the same head, so we need to store the previous head for distillation; where in TIL scenario, each task has its own head, so we can directly use the head of the previous task without storing it separately."""
 67
 68        self.distillation_reg_factor: float = distillation_reg_factor
 69        r"""The distillation regularization factor."""
 70        self.distillation_reg_temperature: float = distillation_reg_temperature
 71        r"""The distillation regularization temperature."""
 72        self.distillation_reg = DistillationReg(
 73            factor=distillation_reg_factor,
 74            temperature=distillation_reg_temperature,
 75            distance="cross_entropy",
 76        )
 77        r"""Initialize and store the distillation regularizer."""
 78
 79        # save additional algorithmic hyperparameters
 80        self.save_hyperparameters(
 81            "distillation_reg_factor",
 82            "distillation_reg_temperature",
 83        )
 84
 85        LwF.sanity_check(self)
 86
 87    def sanity_check(self) -> None:
 88        r"""Sanity check."""
 89
 90        if self.distillation_reg_factor <= 0:
 91            raise ValueError(
 92                f"The distillation regularization factor should be positive, but got {self.distillation_reg_factor}."
 93            )
 94
 95        if self.distillation_reg_temperature <= 0:
 96            raise ValueError(
 97                f"The distillation regularization temperature should be positive, but got {self.distillation_reg_temperature}."
 98            )
 99
100    def training_step(self, batch: Any) -> dict[str, Tensor]:
101        r"""Training step for current task `self.task_id`.
102
103        **Args:**
104        - **batch** (`Any`): a batch of training data.
105
106        **Returns:**
107        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
108        """
109        x, y = batch
110
111        # classification loss. See equation (1) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520)
112        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
113        loss_cls = self.criterion(logits, y)
114
115        # regularization loss. See equation (2) (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520)
116        distillation_reg = 0.0
117        for previous_task_id, previous_backbone in self.previous_task_backbones.items():
118            # sum over all previous models, because [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520) says: "If there are multiple old tasks, or if an old task is multi-label classification, we take the sum of the loss for each old task and label."
119
120            # get the student logits for this batch, using detached head params to avoid updating old heads
121            student_feature, _ = self.backbone(
122                x, stage="train", task_id=previous_task_id
123            )
124            if isinstance(self.heads, HeadDIL):
125                head = self.heads.get_head()
126            elif isinstance(self.heads, HeadsTIL):
127                head = self.heads.get_head(previous_task_id)
128            else:
129                raise TypeError(f"Unsupported heads type {type(self.heads)} in LwF.")
130            student_logits = F.linear(
131                student_feature,
132                head.weight.detach(),
133                head.bias.detach() if head.bias is not None else None,
134            )
135
136            # get the teacher logits for this batch, which is from the previous model
137            with torch.no_grad():  # stop updating the previous backbones and heads
138                teacher_feature, _ = previous_backbone(
139                    x, stage="test", task_id=previous_task_id
140                )
141                if isinstance(self.heads, HeadDIL):
142                    previous_head = self.previous_task_heads[previous_task_id]
143                    teacher_logits = previous_head(teacher_feature)
144                elif isinstance(self.heads, HeadsTIL):
145                    teacher_logits = self.heads(
146                        teacher_feature, task_id=previous_task_id
147                    )
148                else:
149                    raise TypeError(
150                        f"Unsupported heads type {type(self.heads)} in LwF."
151                    )
152
153            distillation_reg += self.distillation_reg(
154                student_logits=student_logits,
155                teacher_logits=teacher_logits,
156            )
157
158        # do not average over tasks to avoid linear increase of the regularization loss. LwF paper doesn't mention this!
159
160        # total loss
161        loss = loss_cls + distillation_reg
162
163        # predicted labels
164        preds = logits.argmax(dim=1)
165
166        # accuracy of the batch
167        acc = (preds == y).float().mean()
168
169        return {
170            "preds": preds,
171            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
172            "loss_cls": loss_cls,
173            "distillation_reg": distillation_reg,
174            "acc": acc,
175            "activations": activations,
176        }
177
178    def on_train_end(self) -> None:
179        r"""Store the backbone model after the training of a task.
180
181        The model is stored in `self.previous_task_backbones` for constructing the regularisation loss in the future tasks.
182        """
183        current_backbone = deepcopy(self.backbone)
184        current_backbone.eval()  # set the store model to evaluation mode to prevent updating
185        if isinstance(self.heads, HeadDIL):
186            current_head = deepcopy(self.heads.get_head())
187            current_head.eval()  # set the store model to evaluation mode to prevent updating
188        self.heads.get_head(
189            self.task_id
190        ).eval()  # set the store model to evaluation mode to prevent updating
191        self.previous_task_backbones[self.task_id] = current_backbone
192        if isinstance(self.heads, HeadDIL):
193            self.previous_task_heads[self.task_id] = current_head
194
195
class AmnesiacLwF(AmnesiacCLAlgorithm, LwF):
    r"""Amnesiac LwF algorithm.

    Combines `LwF` with the machine-unlearning machinery of `AmnesiacCLAlgorithm`
    through cooperative multiple inheritance; the LwF-specific hyperparameters
    (e.g. `distillation_reg_factor`) are forwarded via `**kwargs`.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""Initialize the Amnesiac LwF algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. Currently this LwF supports Task-Incremental Learning (TIL) and Domain-Incremental Learning (DIL) scenarios.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. `None` (the default) is treated as an empty dict.
        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
        - **kwargs**: Reserved for multiple inheritance (e.g. LwF hyperparameters).
        """
        # avoid the shared mutable default argument pitfall of `{}`
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}

        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            disable_unlearning=disable_unlearning,
            **kwargs,
        )

    def on_train_start(self) -> None:
        r"""Run both parents' task-start hooks before training the current task."""
        # call each parent explicitly so neither hook is skipped by the MRO
        LwF.on_train_start(self)
        AmnesiacCLAlgorithm.on_train_start(self)

    def on_train_end(self) -> None:
        r"""Run both parents' task-end hooks after training the current task.

        Stores the LwF teacher model(s) and performs the `AmnesiacCLAlgorithm`
        post-training bookkeeping. (The previous docstring was a copy-paste of
        `on_train_start`'s and incorrectly said "before training".)
        """
        LwF.on_train_end(self)
        AmnesiacCLAlgorithm.on_train_end(self)
 25class LwF(Finetuning):
 26    r"""[LwF (Learning without Forgetting)](https://ieeexplore.ieee.org/abstract/document/8107520) algorithm.
 27
 28    A regularization-based continual learning approach that constrains the feature output of the model to be similar to that of the previous tasks. From the perspective of knowledge distillation, it distills previous tasks models into the training process for new task in the regularization term. It is a simple yet effective method for continual learning.
 29
 30    We implement LwF as a subclass of Finetuning algorithm, as LwF has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class.
 31    """
 32
 33    def __init__(
 34        self,
 35        backbone: CLBackbone,
 36        heads: HeadsTIL | HeadsCIL | HeadDIL,
 37        distillation_reg_factor: float,
 38        distillation_reg_temperature: float,
 39        non_algorithmic_hparams: dict[str, Any] = {},
 40        **kwargs,
 41    ) -> None:
 42        r"""Initialize the LwF algorithm with the network.
 43
 44        **Args:**
 45        - **backbone** (`CLBackbone`): backbone network.
 46        - **heads** (`HeadsTIL` | `HeadDIL`): output heads.
 47        - **distillation_reg_factor** (`float`): hyperparameter, the distillation regularization factor. It controls the strength of preventing forgetting.
 48        - **distillation_reg_temperature** (`float`): hyperparameter, the temperature in the distillation regularization. It controls the softness of the labels that the student model (here is the current model) learns from the teacher models (here are the previous models), thereby controlling the strength of the distillation. It controls the strength of preventing forgetting.
 49        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 50        - **kwargs**: Reserved for multiple inheritance.
 51
 52        """
 53        super().__init__(
 54            backbone=backbone,
 55            heads=heads,
 56            non_algorithmic_hparams=non_algorithmic_hparams,
 57            **kwargs,
 58        )
 59
 60        self.previous_task_backbones: dict[str, nn.Module] = {}
 61        r"""Store the backbone models of the previous tasks. Keys are task IDs (string type) and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained.
 62        
  63    Some would argue that since we could store the models of the previous tasks, why don't we test each task directly with its stored model, instead of using the more involved LwF procedure? The thing is, LwF only uses the models of the previous tasks to train current and future tasks, aggregating them into a single model. Once the training of a task is done, the storage for those parameters can be released. However, this makes future tasks unable to use LwF anymore, which is a disadvantage of LwF.
 64        """
 65        if isinstance(self.heads, HeadDIL):
 66            self.previous_task_heads: dict[str, nn.Module] = {}
 67            r"""The heads snapshot of the previous task (teacher). This is only used when the heads is `HeadDIL`, because in DIL scenario, all tasks share the same head, so we need to store the previous head for distillation; where in TIL scenario, each task has its own head, so we can directly use the head of the previous task without storing it separately."""
 68
 69        self.distillation_reg_factor: float = distillation_reg_factor
 70        r"""The distillation regularization factor."""
 71        self.distillation_reg_temperature: float = distillation_reg_temperature
 72        r"""The distillation regularization temperature."""
 73        self.distillation_reg = DistillationReg(
 74            factor=distillation_reg_factor,
 75            temperature=distillation_reg_temperature,
 76            distance="cross_entropy",
 77        )
 78        r"""Initialize and store the distillation regularizer."""
 79
 80        # save additional algorithmic hyperparameters
 81        self.save_hyperparameters(
 82            "distillation_reg_factor",
 83            "distillation_reg_temperature",
 84        )
 85
 86        LwF.sanity_check(self)
 87
 88    def sanity_check(self) -> None:
 89        r"""Sanity check."""
 90
 91        if self.distillation_reg_factor <= 0:
 92            raise ValueError(
 93                f"The distillation regularization factor should be positive, but got {self.distillation_reg_factor}."
 94            )
 95
 96        if self.distillation_reg_temperature <= 0:
 97            raise ValueError(
 98                f"The distillation regularization temperature should be positive, but got {self.distillation_reg_temperature}."
 99            )
100
101    def training_step(self, batch: Any) -> dict[str, Tensor]:
102        r"""Training step for current task `self.task_id`.
103
104        **Args:**
105        - **batch** (`Any`): a batch of training data.
106
107        **Returns:**
108        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
109        """
110        x, y = batch
111
112        # classification loss. See equation (1) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520)
113        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
114        loss_cls = self.criterion(logits, y)
115
116        # regularization loss. See equation (2) (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520)
117        distillation_reg = 0.0
118        for previous_task_id, previous_backbone in self.previous_task_backbones.items():
119            # sum over all previous models, because [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520) says: "If there are multiple old tasks, or if an old task is multi-label classification, we take the sum of the loss for each old task and label."
120
121            # get the student logits for this batch, using detached head params to avoid updating old heads
122            student_feature, _ = self.backbone(
123                x, stage="train", task_id=previous_task_id
124            )
125            if isinstance(self.heads, HeadDIL):
126                head = self.heads.get_head()
127            elif isinstance(self.heads, HeadsTIL):
128                head = self.heads.get_head(previous_task_id)
129            else:
130                raise TypeError(f"Unsupported heads type {type(self.heads)} in LwF.")
131            student_logits = F.linear(
132                student_feature,
133                head.weight.detach(),
134                head.bias.detach() if head.bias is not None else None,
135            )
136
137            # get the teacher logits for this batch, which is from the previous model
138            with torch.no_grad():  # stop updating the previous backbones and heads
139                teacher_feature, _ = previous_backbone(
140                    x, stage="test", task_id=previous_task_id
141                )
142                if isinstance(self.heads, HeadDIL):
143                    previous_head = self.previous_task_heads[previous_task_id]
144                    teacher_logits = previous_head(teacher_feature)
145                elif isinstance(self.heads, HeadsTIL):
146                    teacher_logits = self.heads(
147                        teacher_feature, task_id=previous_task_id
148                    )
149                else:
150                    raise TypeError(
151                        f"Unsupported heads type {type(self.heads)} in LwF."
152                    )
153
154            distillation_reg += self.distillation_reg(
155                student_logits=student_logits,
156                teacher_logits=teacher_logits,
157            )
158
159        # do not average over tasks to avoid linear increase of the regularization loss. LwF paper doesn't mention this!
160
161        # total loss
162        loss = loss_cls + distillation_reg
163
164        # predicted labels
165        preds = logits.argmax(dim=1)
166
167        # accuracy of the batch
168        acc = (preds == y).float().mean()
169
170        return {
171            "preds": preds,
172            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
173            "loss_cls": loss_cls,
174            "distillation_reg": distillation_reg,
175            "acc": acc,
176            "activations": activations,
177        }
178
179    def on_train_end(self) -> None:
180        r"""Store the backbone model after the training of a task.
181
182        The model is stored in `self.previous_task_backbones` for constructing the regularisation loss in the future tasks.
183        """
184        current_backbone = deepcopy(self.backbone)
185        current_backbone.eval()  # set the store model to evaluation mode to prevent updating
186        if isinstance(self.heads, HeadDIL):
187            current_head = deepcopy(self.heads.get_head())
188            current_head.eval()  # set the store model to evaluation mode to prevent updating
189        self.heads.get_head(
190            self.task_id
191        ).eval()  # set the store model to evaluation mode to prevent updating
192        self.previous_task_backbones[self.task_id] = current_backbone
193        if isinstance(self.heads, HeadDIL):
194            self.previous_task_heads[self.task_id] = current_head

LwF (Learning without Forgetting) algorithm.

A regularization-based continual learning approach that constrains the feature output of the model to be similar to that of the previous tasks. From the perspective of knowledge distillation, it distills previous tasks models into the training process for new task in the regularization term. It is a simple yet effective method for continual learning.

We implement LwF as a subclass of the Finetuning algorithm, as LwF has the same forward(), validation_step() and test_step() methods as the Finetuning class.

LwF( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, distillation_reg_factor: float, distillation_reg_temperature: float, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
33    def __init__(
34        self,
35        backbone: CLBackbone,
36        heads: HeadsTIL | HeadsCIL | HeadDIL,
37        distillation_reg_factor: float,
38        distillation_reg_temperature: float,
39        non_algorithmic_hparams: dict[str, Any] = {},
40        **kwargs,
41    ) -> None:
42        r"""Initialize the LwF algorithm with the network.
43
44        **Args:**
45        - **backbone** (`CLBackbone`): backbone network.
46        - **heads** (`HeadsTIL` | `HeadDIL`): output heads.
47        - **distillation_reg_factor** (`float`): hyperparameter, the distillation regularization factor. It controls the strength of preventing forgetting.
48        - **distillation_reg_temperature** (`float`): hyperparameter, the temperature in the distillation regularization. It controls the softness of the labels that the student model (here is the current model) learns from the teacher models (here are the previous models), thereby controlling the strength of the distillation. It controls the strength of preventing forgetting.
49        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
50        - **kwargs**: Reserved for multiple inheritance.
51
52        """
53        super().__init__(
54            backbone=backbone,
55            heads=heads,
56            non_algorithmic_hparams=non_algorithmic_hparams,
57            **kwargs,
58        )
59
60        self.previous_task_backbones: dict[str, nn.Module] = {}
61        r"""Store the backbone models of the previous tasks. Keys are task IDs (string type) and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained.
62        
63        Some would argue that since we could store the model of the previous tasks, why don't we test the task directly with the stored model, instead of doing the less easier LwF thing? The thing is, LwF only uses the model of the previous tasks to train current and future tasks, which aggregate them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this make the future tasks not able to use LwF anymore, which is a disadvantage for LwF.
64        """
65        if isinstance(self.heads, HeadDIL):
66            self.previous_task_heads: dict[str, nn.Module] = {}
67            r"""The heads snapshot of the previous task (teacher). This is only used when the heads is `HeadDIL`, because in DIL scenario, all tasks share the same head, so we need to store the previous head for distillation; where in TIL scenario, each task has its own head, so we can directly use the head of the previous task without storing it separately."""
68
69        self.distillation_reg_factor: float = distillation_reg_factor
70        r"""The distillation regularization factor."""
71        self.distillation_reg_temperature: float = distillation_reg_temperature
72        r"""The distillation regularization temperature."""
73        self.distillation_reg = DistillationReg(
74            factor=distillation_reg_factor,
75            temperature=distillation_reg_temperature,
76            distance="cross_entropy",
77        )
78        r"""Initialize and store the distillation regularizer."""
79
80        # save additional algorithmic hyperparameters
81        self.save_hyperparameters(
82            "distillation_reg_factor",
83            "distillation_reg_temperature",
84        )
85
86        LwF.sanity_check(self)

Initialize the LwF algorithm with the network.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadDIL): output heads.
  • distillation_reg_factor (float): hyperparameter, the distillation regularization factor. It controls the strength of preventing forgetting.
  • distillation_reg_temperature (float): hyperparameter, the temperature in the distillation regularization. It controls the softness of the labels that the student model (here is the current model) learns from the teacher models (here are the previous models), thereby controlling the strength of the distillation. It controls the strength of preventing forgetting.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
previous_task_backbones: dict[str, torch.nn.modules.module.Module]

Store the backbone models of the previous tasks. Keys are task IDs (string type) and values are the corresponding models. Each model is a nn.Module backbone after the corresponding previous task was trained.

Some would argue that since we could store the models of the previous tasks, why don't we test each task directly with its stored model, instead of using the more involved LwF procedure? The thing is, LwF only uses the models of the previous tasks to train current and future tasks, aggregating them into a single model. Once the training of a task is done, the storage for those parameters can be released. However, this makes future tasks unable to use LwF anymore, which is a disadvantage of LwF.

distillation_reg_factor: float

The distillation regularization factor.

distillation_reg_temperature: float

The distillation regularization temperature.

distillation_reg

Initialize and store the distillation regularizer.

def sanity_check(self) -> None:
88    def sanity_check(self) -> None:
89        r"""Sanity check."""
90
91        if self.distillation_reg_factor <= 0:
92            raise ValueError(
93                f"The distillation regularization factor should be positive, but got {self.distillation_reg_factor}."
94            )
95
96        if self.distillation_reg_temperature <= 0:
97            raise ValueError(
98                f"The distillation regularization temperature should be positive, but got {self.distillation_reg_temperature}."
99            )

Sanity check.

def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
def training_step(self, batch: Any) -> dict[str, Tensor]:
    r"""Train on one batch of the current task `self.task_id`.

    **Args:**
    - **batch** (`Any`): a batch of training data.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to PyTorch Lightning docs.
    """
    x, y = batch

    # classification loss on the new task. See equation (1) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520)
    logits, activations = self.forward(x, stage="train", task_id=self.task_id)
    loss_cls = self.criterion(logits, y)

    # distillation regularization loss. See equations (2) and (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
    # The paper sums the loss over all old tasks ("If there are multiple old
    # tasks, ... we take the sum of the loss for each old task and label."),
    # so accumulate over every stored previous backbone.
    distillation_reg = 0.0
    for old_task_id, old_backbone in self.previous_task_backbones.items():
        # student logits: current backbone features through the old head
        # parameters; detaching the head weights keeps the old heads from
        # being updated while gradients still flow into the current backbone
        feature, _ = self.backbone(x, stage="train", task_id=old_task_id)
        if isinstance(self.heads, HeadDIL):
            head = self.heads.get_head()
        elif isinstance(self.heads, HeadsTIL):
            head = self.heads.get_head(old_task_id)
        else:
            raise TypeError(f"Unsupported heads type {type(self.heads)} in LwF.")
        weight = head.weight.detach()
        bias = head.bias.detach() if head.bias is not None else None
        student_logits = F.linear(feature, weight, bias)

        # teacher logits: the frozen model stored after the old task was
        # trained; no_grad stops any update to the previous backbones and heads
        with torch.no_grad():
            teacher_feature, _ = old_backbone(x, stage="test", task_id=old_task_id)
            if isinstance(self.heads, HeadDIL):
                old_head = self.previous_task_heads[old_task_id]
                teacher_logits = old_head(teacher_feature)
            elif isinstance(self.heads, HeadsTIL):
                teacher_logits = self.heads(teacher_feature, task_id=old_task_id)
            else:
                raise TypeError(f"Unsupported heads type {type(self.heads)} in LwF.")

        distillation_reg = distillation_reg + self.distillation_reg(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
        )

    # NOTE: the sum is deliberately not averaged over tasks; the LwF paper does
    # not specify averaging, so the regularization grows with the task count.

    # total loss
    loss = loss_cls + distillation_reg

    # predicted labels and batch accuracy
    preds = logits.argmax(dim=1)
    acc = (preds == y).float().mean()

    return {
        "preds": preds,
        "loss": loss,  # the 'loss' key is essential, or backpropagation will fail
        "loss_cls": loss_cls,
        "distillation_reg": distillation_reg,
        "acc": acc,
        "activations": activations,
    }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing loss and other metrics from this training step. Keys (str) are the metric names, and values (Tensor) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to PyTorch Lightning docs.
def on_train_end(self) -> None:
def on_train_end(self) -> None:
    r"""Store a frozen copy of the backbone (and, in DIL, the head) after training a task.

    The copies are kept in `self.previous_task_backbones` (and `self.previous_task_heads`
    for DIL) to serve as teacher models when constructing the distillation
    regularization loss of future tasks.
    """
    # snapshot the backbone; eval mode prevents dropout / batch-norm statistics
    # from changing when the copy is later used as a frozen teacher
    current_backbone = deepcopy(self.backbone)
    current_backbone.eval()
    self.previous_task_backbones[self.task_id] = current_backbone

    if isinstance(self.heads, HeadDIL):
        # DIL shares a single head that keeps training on later tasks, so a
        # frozen copy must be stored for distillation
        current_head = deepcopy(self.heads.get_head())
        current_head.eval()
        self.previous_task_heads[self.task_id] = current_head
    else:
        # TIL keeps one head per task inside self.heads; the head of the task
        # just trained is never trained again, so freezing the live head suffices.
        # BUGFIX: this call previously ran unconditionally, passing task_id to
        # HeadDIL.get_head(), which is called without arguments elsewhere in LwF.
        self.heads.get_head(self.task_id).eval()

Store the backbone model after the training of a task.

The model is stored in self.previous_task_backbones for constructing the regularisation loss in the future tasks.

class AmnesiacLwF(clarena.cl_algorithms.base.AmnesiacCLAlgorithm, LwF):
class AmnesiacLwF(AmnesiacCLAlgorithm, LwF):
    r"""Amnesiac LwF algorithm.

    Combines the [LwF](https://ieeexplore.ieee.org/abstract/document/8107520)
    continual learning algorithm with the unlearning mechanism of
    `AmnesiacCLAlgorithm`.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] = {},
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""Initialize the Amnesiac LwF algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. Currently this LwF supports Task-Incremental Learning (TIL) and Domain-Incremental Learning (DIL) scenarios.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        # super() walks the MRO (AmnesiacCLAlgorithm, then LwF); **kwargs is
        # forwarded so every class in the cooperative chain receives its arguments
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            disable_unlearning=disable_unlearning,
            **kwargs,
        )

    def on_train_start(self) -> None:
        r"""Run both parents' `on_train_start` hooks before training the current task.

        The parent hooks are called explicitly (rather than via `super()`) so
        both the LwF setup and the amnesiac bookkeeping are guaranteed to run.
        """
        LwF.on_train_start(self)
        AmnesiacCLAlgorithm.on_train_start(self)

    def on_train_end(self) -> None:
        r"""Run both parents' `on_train_end` hooks after training the current task.

        Stores the frozen LwF teacher model (`LwF.on_train_end`) and runs the
        amnesiac bookkeeping (`AmnesiacCLAlgorithm.on_train_end`).
        (The previous docstring was copy-pasted from `on_train_start` and
        wrongly said "before training".)
        """
        LwF.on_train_end(self)
        AmnesiacCLAlgorithm.on_train_end(self)

Amnesiac LwF algorithm.

AmnesiacLwF( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any] = {}, disable_unlearning: bool = False, **kwargs)
200    def __init__(
201        self,
202        backbone: CLBackbone,
203        heads: HeadsTIL | HeadDIL,
204        non_algorithmic_hparams: dict[str, Any] = {},
205        disable_unlearning: bool = False,
206        **kwargs,
207    ) -> None:
208        r"""Initialize the Amnesiac LwF algorithm with the network.
209
210
211        **Args:**
212        - **backbone** (`CLBackbone`): backbone network.
213        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. Currently this LwF supports Task-Incremental Learning (TIL) and Domain-Incremental Learning (DIL) scenarios.
214        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
215        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
216        - **kwargs**: Reserved for multiple inheritance.
217        """
        # super() walks the MRO (AmnesiacCLAlgorithm, then LwF); **kwargs is
        # forwarded so every class in the cooperative chain receives its arguments
218        super().__init__(
219            backbone=backbone,
220            heads=heads,
221            non_algorithmic_hparams=non_algorithmic_hparams,
222            disable_unlearning=disable_unlearning,
223            **kwargs,
224        )

Initialize the Amnesiac LwF algorithm with the network.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadDIL): output heads. Currently this LwF supports Task-Incremental Learning (TIL) and Domain-Incremental Learning (DIL) scenarios.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • disable_unlearning (bool): whether to disable the unlearning functionality. Default is False.
  • kwargs: Reserved for multiple inheritance.
def on_train_start(self) -> None:
226    def on_train_start(self) -> None:
227        """Run both parents' `on_train_start` hooks (LwF setup and amnesiac parameter recording) before training the current task."""
228        LwF.on_train_start(self)
229        AmnesiacCLAlgorithm.on_train_start(self)

Record backbone parameters before training current task.

def on_train_end(self) -> None:
231    def on_train_end(self) -> None:
232        """Run both parents' `on_train_end` hooks after training the current task (the old docstring was copy-pasted from `on_train_start` and wrongly said "before")."""
233        LwF.on_train_end(self)
234        AmnesiacCLAlgorithm.on_train_end(self)

Record backbone parameters before training current task.