clarena.cl_algorithms.fix

The submodule in cl_algorithms for the Fix algorithm.

 # Source listing of module `clarena.cl_algorithms.fix`.
r"""
The submodule in `cl_algorithms` for the Fix algorithm.
"""

__all__ = ["Fix"]

import logging
from typing import Any

from torch import Tensor

from clarena.backbones import CLBackbone
from clarena.cl_algorithms import Finetuning
from clarena.cl_heads import HeadsCIL, HeadsTIL

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class Fix(Finetuning):
    r"""Fix algorithm.

    Another naive approach to task-incremental learning besides Finetuning. It serves as a toy algorithm when discussing the stability-plasticity dilemma in continual learning: it simply fixes (freezes) the backbone forever after training the first task.

    We implement Fix as a subclass of the Finetuning algorithm, as Fix has the same `forward()`, `validation_step()` and `test_step()` methods as the `Finetuning` class.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
    ) -> None:
        r"""Initialise the Fix algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        """
        # cooperative call so the MRO is respected if Fix is ever mixed into another hierarchy
        super().__init__(backbone=backbone, heads=heads)

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step for the current task `self.task_id`.

        For every task after the first, the backbone parameters are frozen before the forward pass, so gradients no longer flow into them.

        **Args:**
        - **batch** (`Any`): a batch of training data, unpacked as `(x, y)` (inputs and targets).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Key (`str`) is the metric name, value (`Tensor`) is the metric. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
        """
        x, y = batch

        if self.task_id != 1:
            # Task 1 is treated as the first task; from then on the backbone stays fixed.
            # `Module.requires_grad_(False)` sets `requires_grad=False` on every backbone
            # parameter (same effect as looping over `parameters()`); repeating it is idempotent.
            # NOTE(review): this does not stop BatchNorm running statistics from updating in
            # train mode — if the backbone contains BatchNorm layers it is not fully fixed; confirm.
            self.backbone.requires_grad_(False)

        # classification loss
        logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)

        # total loss (Fix adds no regularisation term)
        loss = loss_cls

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # Return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,
            "hidden_features": hidden_features,
        }
class Fix(Finetuning):
    r"""Fix algorithm.

    Another naive approach to task-incremental learning besides Finetuning. It serves as a toy algorithm when discussing the stability-plasticity dilemma in continual learning: it simply fixes (freezes) the backbone forever after training the first task.

    We implement Fix as a subclass of the Finetuning algorithm, as Fix has the same `forward()`, `validation_step()` and `test_step()` methods as the `Finetuning` class.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
    ) -> None:
        r"""Initialise the Fix algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        """
        # cooperative call so the MRO is respected if Fix is ever mixed into another hierarchy
        super().__init__(backbone=backbone, heads=heads)

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step for the current task `self.task_id`.

        For every task after the first, the backbone parameters are frozen before the forward pass, so gradients no longer flow into them.

        **Args:**
        - **batch** (`Any`): a batch of training data, unpacked as `(x, y)` (inputs and targets).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Key (`str`) is the metric name, value (`Tensor`) is the metric. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
        """
        x, y = batch

        if self.task_id != 1:
            # Task 1 is treated as the first task; from then on the backbone stays fixed.
            # `Module.requires_grad_(False)` sets `requires_grad=False` on every backbone
            # parameter (same effect as looping over `parameters()`); repeating it is idempotent.
            # NOTE(review): this does not stop BatchNorm running statistics from updating in
            # train mode — if the backbone contains BatchNorm layers it is not fully fixed; confirm.
            self.backbone.requires_grad_(False)

        # classification loss
        logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)

        # total loss (Fix adds no regularisation term)
        loss = loss_cls

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # Return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,
            "hidden_features": hidden_features,
        }

Fix algorithm.

It is another naive approach to task-incremental learning besides Finetuning. It serves as a toy algorithm when discussing the stability-plasticity dilemma in continual learning: it simply fixes (freezes) the backbone forever after training the first task.

We implement Fix as a subclass of the Finetuning algorithm, as Fix has the same forward(), validation_step() and test_step() methods as the Finetuning class.

def __init__(
    self,
    backbone: CLBackbone,
    heads: HeadsTIL | HeadsCIL,
) -> None:
    r"""Initialise the Fix algorithm with the network. It has no additional hyperparameters.

    **Args:**
    - **backbone** (`CLBackbone`): backbone network.
    - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
    """
    # cooperative call so the MRO is respected if Fix is ever mixed into another hierarchy
    super().__init__(backbone=backbone, heads=heads)

Initialise the Fix algorithm with the network. It has no additional hyperparameters.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL): output heads.
def training_step(self, batch: Any) -> dict[str, Tensor]:
    r"""Training step for the current task `self.task_id`.

    For every task after the first, the backbone parameters are frozen before the forward pass, so gradients no longer flow into them.

    **Args:**
    - **batch** (`Any`): a batch of training data, unpacked as `(x, y)` (inputs and targets).

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Key (`str`) is the metric name, value (`Tensor`) is the metric. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
    """
    x, y = batch

    if self.task_id != 1:
        # Task 1 is treated as the first task; from then on the backbone stays fixed.
        # `Module.requires_grad_(False)` sets `requires_grad=False` on every backbone
        # parameter (same effect as looping over `parameters()`); repeating it is idempotent.
        # NOTE(review): this does not stop BatchNorm running statistics from updating in
        # train mode — if the backbone contains BatchNorm layers it is not fully fixed; confirm.
        self.backbone.requires_grad_(False)

    # classification loss
    logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id)
    loss_cls = self.criterion(logits, y)

    # total loss (Fix adds no regularisation term)
    loss = loss_cls

    # accuracy of the batch
    acc = (logits.argmax(dim=1) == y).float().mean()

    return {
        "loss": loss,  # Return loss is essential for training step, or backpropagation will fail
        "loss_cls": loss_cls,
        "acc": acc,
        "hidden_features": hidden_features,
    }

Training step for the current task self.task_id.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and other metrics from this training step. Key (str) is the metric name, value (Tensor) is the metric. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.