clarena.cl_algorithms.lwf

The submodule in cl_algorithms for the LwF (Learning without Forgetting) algorithm.

r"""
The submodule in `cl_algorithms` for the [LwF (Learning without Forgetting) algorithm](https://ieeexplore.ieee.org/abstract/document/8107520).
"""

__all__ = ["LwF"]

import logging
from copy import deepcopy
from typing import Any

import torch
from torch import Tensor, nn

from clarena.backbones import CLBackbone
from clarena.cl_algorithms import Finetuning
from clarena.cl_algorithms.regularisers import DistillationReg
from clarena.cl_heads import HeadsCIL, HeadsTIL

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class LwF(Finetuning):
    r"""LwF (Learning without Forgetting) algorithm.

    [LwF (Learning without Forgetting, 2017)](https://ieeexplore.ieee.org/abstract/document/8107520) is a regularisation-based continual learning approach that constrains the output of the model on previous tasks to stay similar to that of the models trained on those tasks. From the perspective of knowledge distillation, it distils the models of previous tasks into the training of the new task through the regularisation term. It is a simple yet effective method for continual learning.

    We implement LwF as a subclass of the `Finetuning` algorithm, as LwF has the same `forward()`, `validation_step()` and `test_step()` methods as the `Finetuning` class.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
        distillation_reg_factor: float,
        distillation_reg_temparture: float,
    ) -> None:
        r"""Initialise the LwF algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        - **distillation_reg_factor** (`float`): hyperparameter, the distillation regularisation factor. It controls the strength of preventing forgetting.
        - **distillation_reg_temparture** (`float`): hyperparameter, the temperature in the distillation regularisation. It controls the softness of the labels that the student model (here, the current model) learns from the teacher models (here, the previous models), and thereby the strength of the distillation and of preventing forgetting.
        """
        Finetuning.__init__(self, backbone=backbone, heads=heads)

        # keys are the integer task IDs used in `on_train_end()` and `training_step()`
        self.previous_task_backbones: dict[int, nn.Module] = {}
        r"""Store the backbone models of the previous tasks. Keys are task IDs (integer type) and values are the corresponding models. Each model is a `nn.Module` backbone as it was after the corresponding previous task was trained.

        One might ask: if the models of previous tasks are stored anyway, why not test each task directly with its stored model instead of going through the more involved LwF procedure? The reason is that LwF uses the previous models only while training the current and future tasks, distilling them into a single model. Once training is done, the storage for those parameters can be released. However, releasing it means future tasks can no longer use LwF, which is a disadvantage of LwF.
        """

        self.distillation_reg_factor = distillation_reg_factor
        r"""Store distillation regularisation factor."""
        self.distillation_reg_temperature = distillation_reg_temparture
        r"""Store distillation regularisation temperature."""
        self.distillation_reg = DistillationReg(
            factor=distillation_reg_factor,
            temperature=distillation_reg_temparture,
            distance="cross_entropy",
        )
        r"""Initialise and store the distillation regulariser."""

        LwF.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: If the distillation regularisation factor or the distillation temperature is not positive.
        """

        if self.distillation_reg_factor <= 0:
            raise ValueError(
                f"The distillation regularisation factor should be positive, but got {self.distillation_reg_factor}."
            )

        if self.distillation_reg_temperature <= 0:
            raise ValueError(
                f"The distillation regularisation temperature should be positive, but got {self.distillation_reg_temperature}."
            )

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Each key (`str`) is a metric name and each value (`Tensor`) is the metric. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
        """
        x, y = batch

        # classification loss. See equation (1) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
        logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)

        # regularisation loss. See equations (2) and (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520).
        loss_reg = 0.0
        for previous_task_id in range(1, self.task_id):
            # sum over all previous models, because the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520) says: "If there are multiple old tasks, or if an old task is multi-label classification, we take the sum of the loss for each old task and label."

            # get the student logits for this batch, which come from the current model (through the previous output head)
            student_feature, _ = self.backbone(
                x, stage="train", task_id=previous_task_id
            )
            with torch.no_grad():  # stop updating the previous heads
                student_logits = self.heads(student_feature, task_id=previous_task_id)

            # get the teacher logits for this batch, which come from the previous model
            previous_backbone = self.previous_task_backbones[previous_task_id]
            with torch.no_grad():  # stop updating the previous backbones and heads
                teacher_feature, _ = previous_backbone(
                    x, stage="test", task_id=previous_task_id
                )

                teacher_logits = self.heads(teacher_feature, task_id=previous_task_id)

            loss_reg += self.distillation_reg(
                student_logits=student_logits,
                teacher_logits=teacher_logits,
            )

        if self.task_id != 1:
            loss_reg /= (
                self.task_id
            )  # average over tasks to avoid linear increase of the regularisation loss

        # total loss
        loss = loss_cls + loss_reg

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "loss_reg": loss_reg,
            "acc": acc,
            "hidden_features": hidden_features,
        }

    def on_train_end(self) -> None:
        r"""Store the backbone model after the training of a task.

        The model is stored in `self.previous_task_backbones` for constructing the regularisation loss of future tasks.
        """
        previous_backbone = deepcopy(self.backbone)
        previous_backbone.eval()  # set the stored backbone to evaluation mode
        self.heads.heads[
            f"{self.task_id}"
        ].eval()  # set the head of this task to evaluation mode as well
        self.previous_task_backbones[self.task_id] = previous_backbone

LwF (Learning without Forgetting) algorithm.

LwF (Learning without Forgetting, 2017) is a regularisation-based continual learning approach that constrains the output of the model on previous tasks to stay similar to that of the models trained on those tasks. From the perspective of knowledge distillation, it distils the models of previous tasks into the training of the new task through the regularisation term. It is a simple yet effective method for continual learning.

We implement LwF as a subclass of the Finetuning algorithm, as LwF has the same forward(), validation_step() and test_step() methods as the Finetuning class.

LwF( backbone: clarena.backbones.CLBackbone, heads: clarena.cl_heads.HeadsTIL | clarena.cl_heads.HeadsCIL, distillation_reg_factor: float, distillation_reg_temparture: float)

Initialise the LwF algorithm with the network.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL): output heads.
  • distillation_reg_factor (float): hyperparameter, the distillation regularisation factor. It controls the strength of preventing forgetting.
  • distillation_reg_temparture (float): hyperparameter, the temperature in the distillation regularisation. It controls the softness of the labels that the student model (here, the current model) learns from the teacher models (here, the previous models), and thereby the strength of the distillation and of preventing forgetting.

previous_task_backbones: dict[int, torch.nn.modules.module.Module]

Store the backbone models of the previous tasks. Keys are task IDs (integer type, as used when the models are stored and looked up) and values are the corresponding models. Each model is a nn.Module backbone as it was after the corresponding previous task was trained.

One might ask: if the models of previous tasks are stored anyway, why not test each task directly with its stored model instead of going through the more involved LwF procedure? The reason is that LwF uses the previous models only while training the current and future tasks, distilling them into a single model. Once training is done, the storage for those parameters can be released. However, releasing it means future tasks can no longer use LwF, which is a disadvantage of LwF.

distillation_reg_factor

Store distillation regularisation factor.

distillation_reg_temperature

Store distillation regularisation temperature.

distillation_reg

Initialise and store the distillation regulariser.
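
The implementation of DistillationReg is not shown on this page, so the following is only a minimal sketch of what a temperature-scaled distillation term with a cross-entropy distance might compute. The helpers `softmax` and `distillation_loss` are hypothetical names, and plain Python lists stand in for tensors.

```python
import math


def softmax(logits: list[float], temperature: float = 1.0) -> list[float]:
    """Temperature-scaled softmax; a higher temperature gives softer probabilities."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(
    student_logits: list[float],
    teacher_logits: list[float],
    factor: float,
    temperature: float,
) -> float:
    """Cross-entropy between softened teacher and student distributions, scaled by `factor`.

    Assumption: this mirrors the role of `factor` and `temperature` described above,
    not the actual DistillationReg code.
    """
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    cross_entropy = -sum(
        t * math.log(s) for t, s in zip(teacher_probs, student_probs)
    )
    return factor * cross_entropy


# The loss is smallest when the student reproduces the teacher's logits.
teacher = [2.0, 1.0, 0.1]
print(distillation_loss(teacher, teacher, factor=1.0, temperature=2.0))
print(distillation_loss([0.1, 1.0, 2.0], teacher, factor=1.0, temperature=2.0))
```

In this sketch the factor scales the term linearly, while the temperature flattens both distributions before they are compared.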

def sanity_check(self) -> None:

Check the sanity of the arguments.

Raises:

  • ValueError: If the distillation regularisation factor or the distillation temperature is not positive.

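
The check described above can be illustrated outside the class. This is a small sketch in which `check_positive` is a hypothetical helper, not part of clarena; it only reproduces the "strictly positive" rule and the style of the error message.

```python
def check_positive(name: str, value: float) -> None:
    """Raise ValueError unless `value` is strictly positive (zero is rejected too)."""
    if value <= 0:
        raise ValueError(f"The {name} should be positive, but got {value}.")


# A valid factor passes silently; a zero temperature is rejected.
check_positive("distillation regularisation factor", 1.0)
try:
    check_positive("distillation regularisation temperature", 0.0)
except ValueError as err:
    print(err)
```
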
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and other metrics from this training step. Each key (str) is a metric name and each value (Tensor) is the metric. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.

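
How the training step combines its two loss terms can be sketched with plain numbers. The function `total_loss` below is a hypothetical stand-in that follows the source of training_step: the distillation terms of all previous tasks are summed, the sum is divided by the current task ID when it is not 1, and the result is added to the classification loss.

```python
def total_loss(
    loss_cls: float, distillation_terms: list[float], task_id: int
) -> tuple[float, float]:
    """Return (total loss, regularisation loss) for the current task.

    `distillation_terms` holds one value per previous task, so it has
    `task_id - 1` entries and is empty for the first task.
    """
    loss_reg = sum(distillation_terms)
    if task_id != 1:
        # the source divides by the current task ID, not by the number of previous tasks
        loss_reg /= task_id
    return loss_cls + loss_reg, loss_reg


# First task: no previous models, so only the classification loss remains.
print(total_loss(0.5, [], task_id=1))
# Third task: two previous models contribute one distillation term each.
print(total_loss(1.0, [0.6, 0.3], task_id=3))
```
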
def on_train_end(self) -> None:

Store the backbone model after the training of a task.

The model is stored in self.previous_task_backbones for constructing the regularisation loss of future tasks.
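
The snapshot pattern can be sketched without PyTorch. `ToyBackbone` below is a hypothetical stand-in for a backbone module, with an `eval()` method that only flips a flag. The point of the sketch is that `deepcopy` freezes the state reached at the end of each task, while the live backbone keeps changing.

```python
from copy import deepcopy


class ToyBackbone:
    """Hypothetical stand-in for a backbone: a weight list and a training flag."""

    def __init__(self) -> None:
        self.weights = [0.0, 0.0]
        self.training = True

    def eval(self) -> "ToyBackbone":
        self.training = False
        return self


previous_task_backbones: dict[int, ToyBackbone] = {}
backbone = ToyBackbone()

for task_id in (1, 2):
    # stand-in for training on this task: the live weights change
    backbone.weights = [w + task_id for w in backbone.weights]
    # snapshot the backbone at the end of the task, as on_train_end does
    snapshot = deepcopy(backbone)
    snapshot.eval()
    previous_task_backbones[task_id] = snapshot

# The snapshot of task 1 is unaffected by the training of task 2.
print(previous_task_backbones[1].weights, previous_task_backbones[2].weights)
```
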