clarena.cl_algorithms.lwf
The submodule in `cl_algorithms` for the [LwF (Learning without Forgetting) algorithm](https://ieeexplore.ieee.org/abstract/document/8107520).
1r""" 2The submodule in `cl_algorithms` for [LwF (Learning without Forgetting) algorithm](https://ieeexplore.ieee.org/abstract/document/8107520). 3""" 4 5__all__ = ["LwF"] 6 7import logging 8from copy import deepcopy 9from typing import Any 10 11import torch 12from torch import Tensor, nn 13 14from clarena.backbones import CLBackbone 15from clarena.cl_algorithms import Finetuning 16from clarena.cl_algorithms.regularisers import DistillationReg 17from clarena.cl_heads import HeadsCIL, HeadsTIL 18 19# always get logger for built-in logging in each module 20pylogger = logging.getLogger(__name__) 21 22 23class LwF(Finetuning): 24 r"""LwF (Learning without Forgetting) algorithm. 25 26 [LwF (Learning without Forgetting, 2017)](https://ieeexplore.ieee.org/abstract/document/8107520) is a regularisation-based continual learning approach that constrains the feature output of the model to be similar to that of the previous tasks. From the perspective of knowledge distillation, it distills previous tasks models into the training process for new task in the regularisation term. It is a simple yet effective method for continual learning. 27 28 We implement LwF as a subclass of Finetuning algorithm, as LwF has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class. 29 """ 30 31 def __init__( 32 self, 33 backbone: CLBackbone, 34 heads: HeadsTIL | HeadsCIL, 35 distillation_reg_factor: float, 36 distillation_reg_temparture: float, 37 ) -> None: 38 r"""Initialise the LwF algorithm with the network. 39 40 **Args:** 41 - **backbone** (`CLBackbone`): backbone network. 42 - **heads** (`HeadsTIL` | `HeadsCIL`): output heads. 43 - **distillation_reg_factor** (`float`): hyperparameter, the distillation regularisation factor. It controls the strength of preventing forgetting. 44 - **distillation_reg_temparture** (`float`): hyperparameter, the temperature in the distillation regularisation. It controls the softness of the labels that the student model (here is the current model) learns from the teacher models (here are the previous models), thereby controlling the strength of the distillation. It controls the strength of preventing forgetting. 45 """ 46 Finetuning.__init__(self, backbone=backbone, heads=heads) 47 48 self.previous_task_backbones: dict[str, nn.Module] = {} 49 r"""Store the backbone models of the previous tasks. Keys are task IDs (string type) and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained. 50 51 Some would argue that since we could store the model of the previous tasks, why don't we test the task directly with the stored model, instead of doing the less easier LwF thing? The thing is, LwF only uses the model of the previous tasks to train current and future tasks, which aggregate them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this make the future tasks not able to use LwF anymore, which is a disadvantage for LwF. 
52 """ 53 54 self.distillation_reg_factor = distillation_reg_factor 55 r"""Store distillation regularisation factor.""" 56 self.distillation_reg_temperature = distillation_reg_temparture 57 r"""Store distillation regularisation temperature.""" 58 self.distillation_reg = DistillationReg( 59 factor=distillation_reg_factor, 60 temperature=distillation_reg_temparture, 61 distance="cross_entropy", 62 ) 63 r"""Initialise and store the distillation regulariser.""" 64 65 LwF.sanity_check(self) 66 67 def sanity_check(self) -> None: 68 r"""Check the sanity of the arguments. 69 70 **Raises:** 71 - **ValueError**: If the regularisation factor and distillation temperature is not positive. 72 """ 73 74 if self.distillation_reg_factor <= 0: 75 raise ValueError( 76 f"The distillation regularisation factor should be positive, but got {self.distillation_reg_factor}." 77 ) 78 79 if self.distillation_reg_temperature <= 0: 80 raise ValueError( 81 f"The distillation regularisation temperature should be positive, but got {self.distillation_reg_temperature}." 82 ) 83 84 def training_step(self, batch: Any) -> dict[str, Tensor]: 85 r"""Training step for current task `self.task_id`. 86 87 **Args:** 88 - **batch** (`Any`): a batch of training data. 89 90 **Returns:** 91 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Key (`str`) is the metrics name, value (`Tensor`) is the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 92 """ 93 x, y = batch 94 95 # classification loss. See equation (1) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520). 96 logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id) 97 loss_cls = self.criterion(logits, y) 98 99 # regularisation loss. See equation (2) (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520). 100 loss_reg = 0.0 101 for previous_task_id in range(1, self.task_id): 102 # sum over all previous models, because [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520) says: "If there are multiple old tasks, or if an old task is multi-label classification, we take the sum of the loss for each old task and label." 
103 104 # get the teacher logits for this batch, which is from the current model (to previous output head) 105 student_feature, _ = self.backbone( 106 x, stage="train", task_id=previous_task_id 107 ) 108 with torch.no_grad(): # stop updating the previous heads 109 student_logits = self.heads(student_feature, task_id=previous_task_id) 110 111 # get the teacher logits for this batch, which is from the previous model 112 previous_backbone = self.previous_task_backbones[previous_task_id] 113 with torch.no_grad(): # stop updating the previous backbones and heads 114 teacher_feature, _ = previous_backbone( 115 x, stage="test", task_id=previous_task_id 116 ) 117 118 teacher_logits = self.heads(teacher_feature, task_id=previous_task_id) 119 120 loss_reg += self.distillation_reg( 121 student_logits=student_logits, 122 teacher_logits=teacher_logits, 123 ) 124 125 if self.task_id != 1: 126 loss_reg /= ( 127 self.task_id 128 ) # average over tasks to avoid linear increase of the regularisation loss 129 130 # total loss 131 loss = loss_cls + loss_reg 132 133 # accuracy of the batch 134 acc = (logits.argmax(dim=1) == y).float().mean() 135 136 return { 137 "loss": loss, # Return loss is essential for training step, or backpropagation will fail 138 "loss_cls": loss_cls, 139 "loss_reg": loss_reg, 140 "acc": acc, 141 "hidden_features": hidden_features, 142 } 143 144 def on_train_end(self) -> None: 145 r"""Store the backbone model after the training of a task. 146 147 The model is stored in `self.previous_task_backbones` for constructing the regularisation loss in the future tasks. 148 """ 149 previous_backbone = deepcopy(self.backbone) 150 previous_backbone.eval() # set the store model to evaluation mode to prevent updating 151 self.heads.heads[ 152 f"{self.task_id}" 153 ].eval() # set the store model to evaluation mode to prevent updating 154 self.previous_task_backbones[self.task_id] = previous_backbone
`class LwF(Finetuning)`
LwF (Learning without Forgetting) algorithm.

[LwF (Learning without Forgetting, 2017)](https://ieeexplore.ieee.org/abstract/document/8107520) is a regularisation-based continual learning approach that constrains the feature output of the model to be similar to that of the previous tasks. From the perspective of knowledge distillation, it distills the models of previous tasks into the training process of the new task through the regularisation term. It is a simple yet effective method for continual learning.

We implement LwF as a subclass of the `Finetuning` algorithm, as LwF has the same `forward()`, `validation_step()` and `test_step()` methods as the `Finetuning` class.
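For orientation, a minimal construction sketch follows. The backbone class `MLP` and all constructor arguments shown for the backbone and heads are hypothetical placeholders for illustration, not the package's confirmed API; only the `LwF` keyword arguments come from this module.

```python
# Hypothetical usage sketch: `MLP` and its arguments, and the `HeadsTIL`
# arguments, are assumptions for illustration only.
from clarena.backbones import MLP  # hypothetical backbone choice
from clarena.cl_algorithms import LwF
from clarena.cl_heads import HeadsTIL

backbone = MLP(input_dim=784, hidden_dims=[256, 256], output_dim=64)  # assumed signature
heads = HeadsTIL(input_dim=64)  # assumed signature

model = LwF(
    backbone=backbone,
    heads=heads,
    distillation_reg_factor=1.0,       # strength of the anti-forgetting penalty
    distillation_reg_temperature=2.0,  # softness of the teacher's labels
)
```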
`def __init__(self, backbone: CLBackbone, heads: HeadsTIL | HeadsCIL, distillation_reg_factor: float, distillation_reg_temperature: float) -> None`
Initialise the LwF algorithm with the network.

**Args:**
- **backbone** (`CLBackbone`): backbone network.
- **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
- **distillation_reg_factor** (`float`): hyperparameter, the distillation regularisation factor. It controls the strength of preventing forgetting.
- **distillation_reg_temperature** (`float`): hyperparameter, the temperature in the distillation regularisation. It controls the softness of the labels that the student model (here the current model) learns from the teacher models (here the previous models), thereby controlling the strength of the distillation and hence of preventing forgetting.
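To make the temperature's effect concrete, the following standalone snippet (an illustration only, not part of the module) shows how dividing logits by a larger temperature flattens the resulting distribution:

```python
import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([2.0, 1.0, 0.1])

# T = 1 recovers the ordinary softmax; larger T yields softer targets,
# exposing more of the teacher's relative preferences between classes.
for T in (1.0, 2.0, 5.0):
    print(f"T={T}:", F.softmax(teacher_logits / T, dim=0))
```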
The attribute `previous_task_backbones` (`dict[int, nn.Module]`) stores the backbone models of the previous tasks. Keys are task IDs and values are the corresponding models; each is an `nn.Module` backbone as it was after the corresponding previous task finished training.

Some would argue: since we store the models of the previous tasks, why not test each task directly with its stored model instead of doing the harder LwF thing? The answer is that LwF uses the models of the previous tasks only to train the current and future tasks, aggregating them into a single model. Once training of all tasks is done, the storage for those parameters can be released; however, releasing them makes future tasks unable to use LwF anymore, which is a disadvantage of LwF.
`def sanity_check(self) -> None`
Check the sanity of the arguments.

**Raises:**
- **ValueError**: if the regularisation factor or the distillation temperature is not positive.
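For example, a non-positive factor is rejected at construction time. A small sketch, reusing the hypothetical `backbone` and `heads` from the construction sketch above:

```python
# Hypothetical fail-fast demonstration, reusing `backbone` and `heads` from above.
try:
    LwF(
        backbone=backbone,
        heads=heads,
        distillation_reg_factor=0.0,  # not positive, so rejected
        distillation_reg_temperature=2.0,
    )
except ValueError as err:
    print(err)  # The distillation regularisation factor should be positive, but got 0.0.
```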
`def training_step(self, batch: Any) -> dict[str, Tensor]`
Training step for current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of training data.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are the metric names and values (`Tensor`) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimisation, according to the PyTorch Lightning docs.
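`DistillationReg` is imported from `clarena.cl_algorithms.regularisers` and its implementation is not reproduced on this page. As a rough guide to what a `cross_entropy`-distance distillation regulariser computes, here is a self-contained sketch of the standard temperature-scaled distillation loss from equation (3) of the LwF paper — an assumption about its internals, not the package's actual code:

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def distillation_loss(
    student_logits: Tensor, teacher_logits: Tensor, temperature: float, factor: float
) -> Tensor:
    """Cross-entropy between temperature-softened teacher and student distributions."""
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)   # teacher's soft labels
    log_probs = F.log_softmax(student_logits / temperature, dim=-1)  # student's log-probabilities
    # per-sample cross-entropy against the soft targets, averaged over the batch
    return factor * -(soft_targets * log_probs).sum(dim=-1).mean()


# sanity check: identical logits minimise the loss (it equals the teacher's entropy)
logits = torch.randn(4, 10)
print(distillation_loss(logits, logits, temperature=2.0, factor=1.0))
```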
`def on_train_end(self) -> None`
Store the backbone model after the training of a task.

The model is stored in `self.previous_task_backbones` for constructing the regularisation loss in future tasks.
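Note that `eval()` only switches layers such as dropout and batch normalisation to inference behaviour; it does not stop gradient tracking on its own. A common stricter pattern for snapshotting a frozen teacher (a sketch of an alternative, not what this module does) is:

```python
from copy import deepcopy

from torch import nn


def snapshot_frozen(module: nn.Module) -> nn.Module:
    """Deep-copy a module and fully freeze the copy."""
    frozen = deepcopy(module)
    frozen.eval()  # fix dropout / batch-norm to inference behaviour
    for param in frozen.parameters():
        param.requires_grad_(False)  # exclude the copy from autograd entirely
    return frozen
```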