clarena.cl_algorithms.lwf
The submodule in cl_algorithms for the LwF (Learning without Forgetting) algorithm.
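The heart of LwF is a knowledge-distillation regularizer: the current (student) model is penalized when its temperature-softened output distribution drifts away from that of each stored previous-task (teacher) model. The actual implementation lives in `DistillationReg`, which is not shown on this page; the sketch below is a dependency-free illustration of the standard temperature-scaled distillation cross-entropy it is built around, with all function names hypothetical.

```python
import math


def softmax(logits: list[float], temperature: float = 1.0) -> list[float]:
    """Temperature-scaled softmax; a higher temperature yields softer probabilities."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_cross_entropy(
    student_logits: list[float],
    teacher_logits: list[float],
    temperature: float,
) -> float:
    """Cross-entropy between the temperature-softened teacher and student distributions."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))


# when the student matches the teacher exactly, the loss is at its minimum
# (the teacher distribution's own entropy); diverging logits increase it
loss_same = distillation_cross_entropy([2.0, 0.5, -1.0], [2.0, 0.5, -1.0], temperature=2.0)
loss_diff = distillation_cross_entropy([-1.0, 0.5, 2.0], [2.0, 0.5, -1.0], temperature=2.0)
```

Raising the temperature flattens both distributions, which transfers more of the teacher's "dark knowledge" about non-target classes; this is what the `distillation_reg_temperature` hyperparameter below controls.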
1r""" 2The submodule in `cl_algorithms` for [LwF (Learning without Forgetting) algorithm](https://ieeexplore.ieee.org/abstract/document/8107520). 3""" 4 5__all__ = ["LwF", "AmnesiacLwF"] 6 7import logging 8from copy import deepcopy 9from typing import Any 10 11import torch 12import torch.nn.functional as F 13from torch import Tensor, nn 14 15from clarena.backbones import CLBackbone 16from clarena.cl_algorithms import AmnesiacCLAlgorithm, Finetuning 17from clarena.cl_algorithms.regularizers import DistillationReg 18from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL 19 20# always get logger for built-in logging in each module 21pylogger = logging.getLogger(__name__) 22 23 24class LwF(Finetuning): 25 r"""[LwF (Learning without Forgetting)](https://ieeexplore.ieee.org/abstract/document/8107520) algorithm. 26 27 A regularization-based continual learning approach that constrains the feature output of the model to be similar to that of the previous tasks. From the perspective of knowledge distillation, it distills previous tasks models into the training process for new task in the regularization term. It is a simple yet effective method for continual learning. 28 29 We implement LwF as a subclass of Finetuning algorithm, as LwF has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class. 30 """ 31 32 def __init__( 33 self, 34 backbone: CLBackbone, 35 heads: HeadsTIL | HeadsCIL | HeadDIL, 36 distillation_reg_factor: float, 37 distillation_reg_temperature: float, 38 non_algorithmic_hparams: dict[str, Any] = {}, 39 **kwargs, 40 ) -> None: 41 r"""Initialize the LwF algorithm with the network. 42 43 **Args:** 44 - **backbone** (`CLBackbone`): backbone network. 45 - **heads** (`HeadsTIL` | `HeadDIL`): output heads. 46 - **distillation_reg_factor** (`float`): hyperparameter, the distillation regularization factor. It controls the strength of preventing forgetting. 
47 - **distillation_reg_temperature** (`float`): hyperparameter, the temperature in the distillation regularization. It controls the softness of the labels that the student model (here is the current model) learns from the teacher models (here are the previous models), thereby controlling the strength of the distillation. It controls the strength of preventing forgetting. 48 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 49 - **kwargs**: Reserved for multiple inheritance. 50 51 """ 52 super().__init__( 53 backbone=backbone, 54 heads=heads, 55 non_algorithmic_hparams=non_algorithmic_hparams, 56 **kwargs, 57 ) 58 59 self.previous_task_backbones: dict[str, nn.Module] = {} 60 r"""Store the backbone models of the previous tasks. Keys are task IDs (string type) and values are the corresponding models. Each model is a `nn.Module` backbone after the corresponding previous task was trained. 61 62 Some would argue that since we could store the model of the previous tasks, why don't we test the task directly with the stored model, instead of doing the less easier LwF thing? The thing is, LwF only uses the model of the previous tasks to train current and future tasks, which aggregate them into a single model. Once the training of the task is done, the storage for those parameters can be released. However, this make the future tasks not able to use LwF anymore, which is a disadvantage for LwF. 63 """ 64 if isinstance(self.heads, HeadDIL): 65 self.previous_task_heads: dict[str, nn.Module] = {} 66 r"""The heads snapshot of the previous task (teacher). 
This is only used when the heads is `HeadDIL`, because in DIL scenario, all tasks share the same head, so we need to store the previous head for distillation; where in TIL scenario, each task has its own head, so we can directly use the head of the previous task without storing it separately.""" 67 68 self.distillation_reg_factor: float = distillation_reg_factor 69 r"""The distillation regularization factor.""" 70 self.distillation_reg_temperature: float = distillation_reg_temperature 71 r"""The distillation regularization temperature.""" 72 self.distillation_reg = DistillationReg( 73 factor=distillation_reg_factor, 74 temperature=distillation_reg_temperature, 75 distance="cross_entropy", 76 ) 77 r"""Initialize and store the distillation regularizer.""" 78 79 # save additional algorithmic hyperparameters 80 self.save_hyperparameters( 81 "distillation_reg_factor", 82 "distillation_reg_temperature", 83 ) 84 85 LwF.sanity_check(self) 86 87 def sanity_check(self) -> None: 88 r"""Sanity check.""" 89 90 if self.distillation_reg_factor <= 0: 91 raise ValueError( 92 f"The distillation regularization factor should be positive, but got {self.distillation_reg_factor}." 93 ) 94 95 if self.distillation_reg_temperature <= 0: 96 raise ValueError( 97 f"The distillation regularization temperature should be positive, but got {self.distillation_reg_temperature}." 98 ) 99 100 def training_step(self, batch: Any) -> dict[str, Tensor]: 101 r"""Training step for current task `self.task_id`. 102 103 **Args:** 104 - **batch** (`Any`): a batch of training data. 105 106 **Returns:** 107 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 108 """ 109 x, y = batch 110 111 # classification loss. 
See equation (1) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520) 112 logits, activations = self.forward(x, stage="train", task_id=self.task_id) 113 loss_cls = self.criterion(logits, y) 114 115 # regularization loss. See equation (2) (3) in the [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520) 116 distillation_reg = 0.0 117 for previous_task_id, previous_backbone in self.previous_task_backbones.items(): 118 # sum over all previous models, because [LwF paper](https://ieeexplore.ieee.org/abstract/document/8107520) says: "If there are multiple old tasks, or if an old task is multi-label classification, we take the sum of the loss for each old task and label." 119 120 # get the student logits for this batch, using detached head params to avoid updating old heads 121 student_feature, _ = self.backbone( 122 x, stage="train", task_id=previous_task_id 123 ) 124 if isinstance(self.heads, HeadDIL): 125 head = self.heads.get_head() 126 elif isinstance(self.heads, HeadsTIL): 127 head = self.heads.get_head(previous_task_id) 128 else: 129 raise TypeError(f"Unsupported heads type {type(self.heads)} in LwF.") 130 student_logits = F.linear( 131 student_feature, 132 head.weight.detach(), 133 head.bias.detach() if head.bias is not None else None, 134 ) 135 136 # get the teacher logits for this batch, which is from the previous model 137 with torch.no_grad(): # stop updating the previous backbones and heads 138 teacher_feature, _ = previous_backbone( 139 x, stage="test", task_id=previous_task_id 140 ) 141 if isinstance(self.heads, HeadDIL): 142 previous_head = self.previous_task_heads[previous_task_id] 143 teacher_logits = previous_head(teacher_feature) 144 elif isinstance(self.heads, HeadsTIL): 145 teacher_logits = self.heads( 146 teacher_feature, task_id=previous_task_id 147 ) 148 else: 149 raise TypeError( 150 f"Unsupported heads type {type(self.heads)} in LwF." 
151 ) 152 153 distillation_reg += self.distillation_reg( 154 student_logits=student_logits, 155 teacher_logits=teacher_logits, 156 ) 157 158 # do not average over tasks to avoid linear increase of the regularization loss. LwF paper doesn't mention this! 159 160 # total loss 161 loss = loss_cls + distillation_reg 162 163 # predicted labels 164 preds = logits.argmax(dim=1) 165 166 # accuracy of the batch 167 acc = (preds == y).float().mean() 168 169 return { 170 "preds": preds, 171 "loss": loss, # return loss is essential for training step, or backpropagation will fail 172 "loss_cls": loss_cls, 173 "distillation_reg": distillation_reg, 174 "acc": acc, 175 "activations": activations, 176 } 177 178 def on_train_end(self) -> None: 179 r"""Store the backbone model after the training of a task. 180 181 The model is stored in `self.previous_task_backbones` for constructing the regularisation loss in the future tasks. 182 """ 183 current_backbone = deepcopy(self.backbone) 184 current_backbone.eval() # set the store model to evaluation mode to prevent updating 185 if isinstance(self.heads, HeadDIL): 186 current_head = deepcopy(self.heads.get_head()) 187 current_head.eval() # set the store model to evaluation mode to prevent updating 188 self.heads.get_head( 189 self.task_id 190 ).eval() # set the store model to evaluation mode to prevent updating 191 self.previous_task_backbones[self.task_id] = current_backbone 192 if isinstance(self.heads, HeadDIL): 193 self.previous_task_heads[self.task_id] = current_head 194 195 196class AmnesiacLwF(AmnesiacCLAlgorithm, LwF): 197 r"""Amnesiac LwF algorithm.""" 198 199 def __init__( 200 self, 201 backbone: CLBackbone, 202 heads: HeadsTIL | HeadDIL, 203 non_algorithmic_hparams: dict[str, Any] = {}, 204 disable_unlearning: bool = False, 205 **kwargs, 206 ) -> None: 207 r"""Initialize the Amnesiac LwF algorithm with the network. 208 209 210 **Args:** 211 - **backbone** (`CLBackbone`): backbone network. 
212 - **heads** (`HeadsTIL` | `HeadDIL`): output heads. Currently this LwF supports Task-Incremental Learning (TIL) and Domain-Incremental Learning (DIL) scenarios. 213 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 214 - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`. 215 - **kwargs**: Reserved for multiple inheritance. 216 """ 217 super().__init__( 218 backbone=backbone, 219 heads=heads, 220 non_algorithmic_hparams=non_algorithmic_hparams, 221 disable_unlearning=disable_unlearning, 222 **kwargs, 223 ) 224 225 def on_train_start(self) -> None: 226 """Record backbone parameters before training current task.""" 227 LwF.on_train_start(self) 228 AmnesiacCLAlgorithm.on_train_start(self) 229 230 def on_train_end(self) -> None: 231 """Record backbone parameters before training current task.""" 232 LwF.on_train_end(self) 233 AmnesiacCLAlgorithm.on_train_end(self)
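In `training_step` above, the distillation penalties of all stored teachers are summed onto the classification loss, and `on_train_end` snapshots the just-trained model as a frozen teacher for future tasks. The toy class below mimics only that bookkeeping pattern in dependency-free Python; it is a hypothetical mock, not the real `clarena` API, and a squared-distance penalty stands in for the actual distillation cross-entropy.

```python
from copy import deepcopy


class ToyLwF:
    """Toy mock of LwF bookkeeping: snapshot one frozen teacher per finished
    task, and sum a per-teacher penalty into the training loss."""

    def __init__(self, distillation_reg_factor: float) -> None:
        self.factor = distillation_reg_factor
        self.weights: list[float] = [0.0]  # stand-in for the live backbone parameters
        self.previous_task_backbones: dict[str, list[float]] = {}
        self.task_id = "0"

    def training_loss(self, loss_cls: float) -> float:
        # penalty: squared distance to each frozen teacher, summed over all
        # previous tasks (the real algorithm sums distillation losses instead)
        reg = sum(
            (w - tw) ** 2
            for teacher in self.previous_task_backbones.values()
            for w, tw in zip(self.weights, teacher)
        )
        return loss_cls + self.factor * reg

    def on_train_end(self) -> None:
        # snapshot the current weights as the frozen teacher for this task
        self.previous_task_backbones[self.task_id] = deepcopy(self.weights)
        self.task_id = str(int(self.task_id) + 1)


algo = ToyLwF(distillation_reg_factor=0.5)
first_loss = algo.training_loss(1.0)   # no teachers yet, so loss == loss_cls
algo.on_train_end()                    # task "0" becomes a frozen teacher
algo.weights = [2.0]                   # simulate drift while training task "1"
second_loss = algo.training_loss(1.0)  # now penalized for drifting from the teacher
```

This mirrors why the first task in LwF trains exactly like plain finetuning: the teacher dictionary is empty, so the regularization term vanishes and only the classification loss remains.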