clarena.cl_algorithms.fix
The submodule in `cl_algorithms` for the Fix algorithm.
1r""" 2The submodule in `cl_algorithms` for Fix algorithm. 3""" 4 5__all__ = ["Fix"] 6 7import logging 8from typing import Any 9 10from torch import Tensor 11 12from clarena.backbones import CLBackbone 13from clarena.cl_algorithms import Finetuning 14from clarena.heads import HeadsCIL, HeadsTIL 15 16# always get logger for built-in logging in each module 17pylogger = logging.getLogger(__name__) 18 19 20class Fix(Finetuning): 21 r"""Fix algorithm. 22 23 Another naive way for task-incremental learning aside from Finetuning. It simply fixes the backbone forever after training first task. It serves as kind of toy algorithm when discussing stability-plasticity dilemma in continual learning. 24 25 We implement `Fix` as a subclass of `Finetuning`, as it shares `forward()`, `validation_step()`, and `test_step()` with `Finetuning`. 26 """ 27 28 def __init__( 29 self, 30 backbone: CLBackbone, 31 heads: HeadsTIL | HeadsCIL, 32 non_algorithmic_hparams: dict[str, Any] = {}, 33 ) -> None: 34 r"""Initialize the Fix algorithm with the network. It has no additional hyperparameters. 35 36 **Args:** 37 - **backbone** (`CLBackbone`): backbone network. 38 - **heads** (`HeadsTIL` | `HeadsCIL`): output heads. 39 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 40 41 """ 42 super().__init__( 43 backbone=backbone, 44 heads=heads, 45 non_algorithmic_hparams=non_algorithmic_hparams, 46 ) 47 48 # freeze only once after task 1 49 self._backbone_frozen: bool = False 50 51 def training_step(self, batch: Any) -> dict[str, Tensor]: 52 """Training step for current task `self.task_id`. 53 54 **Args:** 55 - **batch** (`Any`): a batch of training data. 56 57 **Returns:** 58 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 59 """ 60 x, y = batch 61 62 if self.task_id != 1: 63 # freeze the backbone once after the first task; also stop BN/Dropout updates 64 if not self._backbone_frozen: 65 for p in self.backbone.parameters(): 66 p.requires_grad = False 67 self.backbone.eval() 68 self._backbone_frozen = True 69 pylogger.info("Fix: backbone frozen after task 1 (set to eval mode).") 70 else: 71 # ensure backbone is trainable during the first task 72 self.backbone.train() 73 74 # classification loss 75 logits, activations = self.forward(x, stage="train", task_id=self.task_id) 76 loss_cls = self.criterion(logits, y) 77 78 # total loss 79 loss = loss_cls 80 81 # accuracy of the batch 82 acc = (logits.argmax(dim=1) == y).float().mean() 83 84 return { 85 "loss": loss, # return loss is essential for training step, or backpropagation will fail 86 "loss_cls": loss_cls, 87 "acc": acc, 88 "activations": activations, 89 }
class Fix(Finetuning):
Fix algorithm.
Another naive approach to task-incremental learning besides Finetuning. It simply fixes the backbone forever after training the first task. It serves as a kind of toy algorithm when discussing the stability-plasticity dilemma in continual learning.
We implement `Fix` as a subclass of `Finetuning`, as it shares `forward()`, `validation_step()`, and `test_step()` with `Finetuning`.
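The contrast with Finetuning comes down to whether the backbone keeps receiving gradient updates after task 1. Below is a minimal, self-contained PyTorch sketch of that freezing idea, independent of clarena; the toy `backbone`, `head`, and `freeze_module` names are purely illustrative and are not part of this library.

```python
from torch import nn

# Toy modules for illustration only; clarena uses its own CLBackbone and heads classes.
backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Dropout(0.1), nn.Linear(64, 64))
head = nn.Linear(64, 10)  # in task-incremental learning, one such head is kept per task


def freeze_module(module: nn.Module) -> None:
    """Exclude a module from gradient updates and put BN/Dropout layers into inference mode."""
    for p in module.parameters():
        p.requires_grad = False
    module.eval()


# Task 1 is trained as usual (same as Finetuning). From task 2 onward, Fix-style
# training freezes the backbone, so only the current task's head keeps learning.
freeze_module(backbone)
trainable = [p for p in head.parameters() if p.requires_grad]
assert all(not p.requires_grad for p in backbone.parameters())
print(f"trainable tensors after freezing: {len(trainable)}")  # the head's weight and bias
```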
Fix(backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL, non_algorithmic_hparams: dict[str, typing.Any] = {})
```python
    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""Initialize the Fix algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): hyperparameters that are not related to the algorithm itself but are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )

        # freeze only once after task 1
        self._backbone_frozen: bool = False
```
Initialize the Fix algorithm with the network. It has no additional hyperparameters.
**Args:**
- **backbone** (`CLBackbone`): backbone network.
- **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
- **non_algorithmic_hparams** (`dict[str, Any]`): hyperparameters that are not related to the algorithm itself but are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
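A hedged instantiation sketch follows. It assumes a concrete `CLBackbone` and a `HeadsTIL` (or `HeadsCIL`) instance have already been built elsewhere, typically from the experiment config; their constructors are not shown on this page, and the `build_fix` helper and the `"optimizer"` entry below are hypothetical examples of the optimizer-style settings the docstring mentions.

```python
from typing import Any

from clarena.cl_algorithms.fix import Fix


def build_fix(backbone, heads, optimizer_cfg: dict[str, Any]) -> Fix:
    """Assemble a Fix algorithm from an already-built backbone and heads.

    Hypothetical helper: `backbone` must be a CLBackbone and `heads` a HeadsTIL or
    HeadsCIL instance; how they are constructed is defined by the experiment config,
    not here.
    """
    return Fix(
        backbone=backbone,
        heads=heads,
        # illustrative contents only: optimizer / LR scheduler settings passed
        # through from the config and, per the docstring, recorded via Lightning's
        # save_hyperparameters().
        non_algorithmic_hparams={"optimizer": optimizer_cfg},
    )
```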
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
```python
    def training_step(self, batch: Any) -> dict[str, Tensor]:
        """Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
        """
        x, y = batch

        if self.task_id != 1:
            # freeze the backbone once after the first task; also stop BN/Dropout updates
            if not self._backbone_frozen:
                for p in self.backbone.parameters():
                    p.requires_grad = False
                self.backbone.eval()
                self._backbone_frozen = True
                pylogger.info("Fix: backbone frozen after task 1 (set to eval mode).")
        else:
            # ensure the backbone is trainable during the first task
            self.backbone.train()

        # classification loss
        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)

        # total loss
        loss = loss_cls

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,
            "activations": activations,
        }
```
Training step for current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of training data.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
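Because the freeze happens inside `training_step`, a quick way to sanity-check the behavior is to inspect the backbone once training has moved past task 1. The `backbone_is_fixed` helper below is hypothetical; it relies only on the `backbone` attribute and `_backbone_frozen` flag visible in the source above.

```python
def backbone_is_fixed(model) -> bool:
    """Check that a Fix instance's backbone is fully excluded from gradient updates.

    Hypothetical helper: `model` is a trained Fix object as sketched earlier.
    """
    frozen_params = all(not p.requires_grad for p in model.backbone.parameters())
    return model._backbone_frozen and frozen_params


# After the first training step of any task beyond task 1 this should hold,
# while during task 1 the backbone parameters still have requires_grad=True.
# assert backbone_is_fixed(model)
```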