clarena.cl_algorithms.fix
The submodule in `cl_algorithms` for the Fix algorithm.
1r""" 2The submodule in `cl_algorithms` for Fix algorithm. 3""" 4 5__all__ = ["Fix"] 6 7import logging 8from typing import Any 9 10from torch import Tensor 11 12from clarena.backbones import CLBackbone 13from clarena.cl_algorithms import Finetuning 14from clarena.cl_heads import HeadsCIL, HeadsTIL 15 16# always get logger for built-in logging in each module 17pylogger = logging.getLogger(__name__) 18 19 20class Fix(Finetuning): 21 r"""Fix algorithm. 22 23 It is another naive way for task-incremental learning aside from Finetuning. It serves as kind of toy algorithm when discussing stability-plasticity dilemma in continual learning. It simply fixes the backbone forever after training first task. 24 25 We implement Fix as a subclass of Finetuning algorithm, as Fix has the same `forward()`, `validation_step()` and `test_step()` method as `Finetuning` class. 26 """ 27 28 def __init__( 29 self, 30 backbone: CLBackbone, 31 heads: HeadsTIL | HeadsCIL, 32 ) -> None: 33 r"""Initialise the Fix algorithm with the network. It has no additional hyperparamaters. 34 35 **Args:** 36 - **backbone** (`CLBackbone`): backbone network. 37 - **heads** (`HeadsTIL` | `HeadsCIL`): output heads. 38 """ 39 Finetuning.__init__(self, backbone=backbone, heads=heads) 40 41 def training_step(self, batch: Any) -> dict[str, Tensor]: 42 """Training step for current task `self.task_id`. 43 44 **Args:** 45 - **batch** (`Any`): a batch of training data. 46 47 **Returns:** 48 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Key (`str`) is the metrics name, value (`Tensor`) is the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 49 """ 50 x, y = batch 51 52 if self.task_id != 1: 53 # Fix the backbone after training the first task 54 for param in self.backbone.parameters(): 55 param.requires_grad = False 56 57 # classification loss 58 logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id) 59 loss_cls = self.criterion(logits, y) 60 61 # total loss 62 loss = loss_cls 63 64 # accuracy of the batch 65 acc = (logits.argmax(dim=1) == y).float().mean() 66 67 return { 68 "loss": loss, # Return loss is essential for training step, or backpropagation will fail 69 "loss_cls": loss_cls, 70 "acc": acc, 71 "hidden_features": hidden_features, 72 }
Fix algorithm.

It is another naive approach to task-incremental learning aside from Finetuning. It serves as a kind of toy algorithm when discussing the stability-plasticity dilemma in continual learning. It simply fixes the backbone forever after training the first task.

We implement Fix as a subclass of the Finetuning algorithm, as Fix has the same `forward()`, `validation_step()` and `test_step()` methods as the `Finetuning` class.
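Since Fix reuses everything from Finetuning except the training step, the subclass relationship can be checked directly. A minimal sketch, assuming only that the `clarena` package is importable:

```python
from clarena.cl_algorithms import Finetuning, Fix

assert issubclass(Fix, Finetuning)

# Fix only overrides training_step(); the other steps are inherited unchanged.
assert Fix.forward is Finetuning.forward
assert Fix.validation_step is Finetuning.validation_step
assert Fix.test_step is Finetuning.test_step
assert Fix.training_step is not Finetuning.training_step
```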
Fix(backbone: clarena.backbones.CLBackbone, heads: clarena.cl_heads.HeadsTIL | clarena.cl_heads.HeadsCIL)
Initialise the Fix algorithm with the network. It has no additional hyperparameters.

Args:
- backbone (`CLBackbone`): backbone network.
- heads (`HeadsTIL` | `HeadsCIL`): output heads.
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
Training step for the current task `self.task_id`.

Args:
- batch (`Any`): a batch of training data.

Returns:
- outputs (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. The key (`str`) is the metric name and the value (`Tensor`) is the metric. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
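The freezing at the start of `training_step` is plain `requires_grad = False` on the backbone parameters, so from the second task onwards only the output heads receive gradients. A minimal stand-alone sketch of that effect, using generic `torch.nn` modules as stand-ins rather than the actual `CLBackbone` and heads classes:

```python
import torch
from torch import nn
from torch.nn import functional as F

# Stand-ins for the backbone and one task head (not the real clarena classes).
backbone = nn.Linear(4, 8)
head = nn.Linear(8, 2)

# The same freezing loop Fix applies after the first task.
for param in backbone.parameters():
    param.requires_grad = False

x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
logits = head(backbone(x))
loss = F.cross_entropy(logits, y)
loss.backward()

print(backbone.weight.grad)          # None: the frozen backbone gets no gradients
print(head.weight.grad is not None)  # True: the head is still being trained
```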