clarena.cl_algorithms.fix
The submodule in cl_algorithms for Fix algorithm.
1r""" 2The submodule in `cl_algorithms` for Fix algorithm. 3""" 4 5__all__ = ["Fix"] 6 7import logging 8from typing import Any 9 10from torch import Tensor 11 12from clarena.backbones import CLBackbone 13from clarena.cl_algorithms import Finetuning 14from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL 15 16# always get logger for built-in logging in each module 17pylogger = logging.getLogger(__name__) 18 19 20class Fix(Finetuning): 21 r"""Fix algorithm. 22 23 Another naive way for task-incremental learning aside from Finetuning. It simply fixes the backbone forever after training first task. It serves as kind of toy algorithm when discussing stability-plasticity dilemma in continual learning. 24 25 We implement `Fix` as a subclass of `Finetuning`, as it shares `forward()`, `validation_step()`, and `test_step()` with `Finetuning`. 26 """ 27 28 def __init__( 29 self, 30 backbone: CLBackbone, 31 heads: HeadsTIL | HeadsCIL | HeadDIL, 32 non_algorithmic_hparams: dict[str, Any] = {}, 33 **kwargs, 34 ) -> None: 35 r"""Initialize the Fix algorithm with the network. It has no additional hyperparameters. 36 37 **Args:** 38 - **backbone** (`CLBackbone`): backbone network. 39 - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads. 40 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 41 - **kwargs**: Reserved for multiple inheritance. 
42 43 """ 44 super().__init__( 45 backbone=backbone, 46 heads=heads, 47 non_algorithmic_hparams=non_algorithmic_hparams, 48 **kwargs, 49 ) 50 51 # freeze only once after task 1 52 self._backbone_frozen: bool = False 53 54 def training_step(self, batch: Any) -> dict[str, Tensor]: 55 """Training step for current task `self.task_id`. 56 57 **Args:** 58 - **batch** (`Any`): a batch of training data. 59 60 **Returns:** 61 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 62 """ 63 x, y = batch 64 65 if self.task_id != 1: 66 # freeze the backbone once after the first task; also stop BN/Dropout updates 67 if not self._backbone_frozen: 68 for p in self.backbone.parameters(): 69 p.requires_grad = False 70 self.backbone.eval() 71 self._backbone_frozen = True 72 pylogger.info("Fix: backbone frozen after task 1 (set to eval mode).") 73 else: 74 # ensure backbone is trainable during the first task 75 self.backbone.train() 76 77 # classification loss 78 logits, activations = self.forward(x, stage="train", task_id=self.task_id) 79 loss_cls = self.criterion(logits, y) 80 81 # total loss 82 loss = loss_cls 83 84 # predicted labels 85 preds = logits.argmax(dim=1) 86 87 # accuracy of the batch 88 acc = (preds == y).float().mean() 89 90 return { 91 "preds": preds, 92 "loss": loss, # return loss is essential for training step, or backpropagation will fail 93 "loss_cls": loss_cls, 94 "acc": acc, 95 "activations": activations, 96 }
class Fix(Finetuning):
    r"""Fix algorithm.

    Another naive way for task-incremental learning aside from Finetuning. It simply fixes the backbone forever after training the first task. It serves as a kind of toy algorithm when discussing the stability-plasticity dilemma in continual learning.

    We implement `Fix` as a subclass of `Finetuning`, as it shares `forward()`, `validation_step()`, and `test_step()` with `Finetuning`.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        r"""Initialize the Fix algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. Defaults to an empty dict.
        - **kwargs**: Reserved for multiple inheritance.
        """
        # `None` sentinel instead of a mutable `{}` default: a dict default
        # object would be shared across every instance of the class.
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}

        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        # whether backbone parameters have been frozen (done once, after task 1)
        self._backbone_frozen: bool = False

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        """Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
        """
        x, y = batch

        if self.task_id != 1:
            # freeze the backbone parameters only once after the first task
            if not self._backbone_frozen:
                for p in self.backbone.parameters():
                    p.requires_grad = False
                self._backbone_frozen = True
                pylogger.info("Fix: backbone frozen after task 1 (set to eval mode).")
            # Lightning puts the module back into train mode at the start of
            # each fit loop, so eval() must be re-applied every step (not only
            # once at freeze time) to keep BatchNorm statistics and Dropout
            # behavior fixed.
            self.backbone.eval()
        else:
            # ensure backbone is trainable during the first task
            self.backbone.train()

        # classification loss
        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)

        # total loss (Fix adds no regularization terms)
        loss = loss_cls

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,
            "activations": activations,
        }
Fix algorithm.
Another naive way for task-incremental learning aside from Finetuning. It simply fixes the backbone forever after training the first task. It serves as a kind of toy algorithm when discussing the stability-plasticity dilemma in continual learning.
We implement Fix as a subclass of Finetuning, as it shares forward(), validation_step(), and test_step() with Finetuning.
Fix( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
def __init__(
    self,
    backbone: CLBackbone,
    heads: HeadsTIL | HeadsCIL | HeadDIL,
    non_algorithmic_hparams: dict[str, Any] | None = None,
    **kwargs,
) -> None:
    r"""Initialize the Fix algorithm with the network. It has no additional hyperparameters.

    **Args:**
    - **backbone** (`CLBackbone`): backbone network.
    - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
    - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. Defaults to an empty dict.
    - **kwargs**: Reserved for multiple inheritance.
    """
    # `None` sentinel instead of a mutable `{}` default: a dict default
    # object would be shared across every instance of the class.
    if non_algorithmic_hparams is None:
        non_algorithmic_hparams = {}

    super().__init__(
        backbone=backbone,
        heads=heads,
        non_algorithmic_hparams=non_algorithmic_hparams,
        **kwargs,
    )

    # whether backbone parameters have been frozen (done once, after task 1)
    self._backbone_frozen: bool = False
Initialize the Fix algorithm with the network. It has no additional hyperparameters.
Args:
- backbone (
CLBackbone): backbone network. - heads (
HeadsTIL | HeadsCIL | HeadDIL): output heads. - non_algorithmic_hparams (
dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from the `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. - kwargs: Reserved for multiple inheritance.
def
training_step(self, batch: Any) -> dict[str, torch.Tensor]:
def training_step(self, batch: Any) -> dict[str, Tensor]:
    """Training step for current task `self.task_id`.

    **Args:**
    - **batch** (`Any`): a batch of training data.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
    """
    x, y = batch

    if self.task_id != 1:
        # freeze the backbone parameters only once after the first task
        if not self._backbone_frozen:
            for p in self.backbone.parameters():
                p.requires_grad = False
            self._backbone_frozen = True
            pylogger.info("Fix: backbone frozen after task 1 (set to eval mode).")
        # Lightning puts the module back into train mode at the start of each
        # fit loop, so eval() must be re-applied every step (not only once at
        # freeze time) to keep BatchNorm statistics and Dropout behavior fixed.
        self.backbone.eval()
    else:
        # ensure backbone is trainable during the first task
        self.backbone.train()

    # classification loss
    logits, activations = self.forward(x, stage="train", task_id=self.task_id)
    loss_cls = self.criterion(logits, y)

    # total loss (Fix adds no regularization terms)
    loss = loss_cls

    # predicted labels
    preds = logits.argmax(dim=1)

    # accuracy of the batch
    acc = (preds == y).float().mean()

    return {
        "preds": preds,
        "loss": loss,  # return loss is essential for training step, or backpropagation will fail
        "loss_cls": loss_cls,
        "acc": acc,
        "activations": activations,
    }
Training step for current task self.task_id.
Args:
- batch (
Any): a batch of training data.
Returns:
- outputs (
dict[str, Tensor]): a dictionary contains loss and other metrics from this training step. Keys (str) are the metrics names, and values (Tensor) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.