clarena.cl_algorithms.fix

The submodule in `cl_algorithms` for the Fix algorithm.

 1r"""
 2The submodule in `cl_algorithms` for Fix algorithm.
 3"""
 4
 5__all__ = ["Fix"]
 6
 7import logging
 8from typing import Any
 9
10from torch import Tensor
11
12from clarena.backbones import CLBackbone
13from clarena.cl_algorithms import Finetuning
14from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL
15
16# always get logger for built-in logging in each module
17pylogger = logging.getLogger(__name__)
18
19
class Fix(Finetuning):
    r"""Fix algorithm.

    Another naive way for task-incremental learning aside from Finetuning. It simply fixes the backbone forever after training the first task. It serves as a kind of toy algorithm when discussing the stability-plasticity dilemma in continual learning.

    We implement `Fix` as a subclass of `Finetuning`, as it shares `forward()`, `validation_step()`, and `test_step()` with `Finetuning`.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        r"""Initialize the Fix algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. `None` (the default) means an empty dict; `None` is used as the sentinel to avoid a shared mutable default argument.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            # convert the None sentinel here so the parent still receives a dict
            non_algorithmic_hparams=(
                non_algorithmic_hparams if non_algorithmic_hparams is not None else {}
            ),
            **kwargs,
        )

        # one-shot flag: backbone parameters are frozen exactly once, after task 1
        self._backbone_frozen: bool = False

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        """Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary that contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
        """
        x, y = batch

        if self.task_id != 1:
            if not self._backbone_frozen:
                # freeze parameters only once after the first task
                for p in self.backbone.parameters():
                    p.requires_grad = False
                self._backbone_frozen = True
                pylogger.info("Fix: backbone frozen after task 1 (set to eval mode).")
            # re-apply eval mode on EVERY step: Lightning switches the whole module
            # back to train mode at the start of each training epoch, which would
            # otherwise re-enable BatchNorm running-stat updates and Dropout even
            # though gradients are disabled.
            self.backbone.eval()
        else:
            # ensure backbone is trainable during the first task
            self.backbone.train()

        # classification loss
        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)

        # total loss (Fix adds no regularization terms)
        loss = loss_cls

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,
            "activations": activations,
        }
class Fix(Finetuning):
    r"""Fix algorithm.

    A second naive baseline for task-incremental learning besides Finetuning: the backbone is frozen permanently once the first task has been trained. It is mostly a toy algorithm, useful when illustrating the stability-plasticity dilemma in continual learning.

    `Fix` subclasses `Finetuning` because `forward()`, `validation_step()`, and `test_step()` are identical between the two.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] = {},
        **kwargs,
    ) -> None:
        r"""Construct the Fix algorithm around the given network. No extra hyperparameters exist.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): hyperparameters unrelated to the algorithm itself (optimizer, LR scheduler configs, ...) forwarded to this `LightningModule` from the config and stored via `save_hyperparameters()` for reproducibility.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        # one-shot flag: the backbone gets frozen exactly once, after task 1
        self._backbone_frozen: bool = False

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        """Run one training step for the current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): metric name -> metric tensor for this step. The 'loss' key must be present under automatic optimization (PyTorch Lightning requirement).
        """
        inputs, targets = batch

        if self.task_id == 1:
            # first task: the backbone must stay trainable
            self.backbone.train()
        elif not self._backbone_frozen:
            # any later task: freeze once, and put BN/Dropout into eval mode
            for param in self.backbone.parameters():
                param.requires_grad = False
            self.backbone.eval()
            self._backbone_frozen = True
            pylogger.info("Fix: backbone frozen after task 1 (set to eval mode).")

        # classification loss on this batch
        logits, activations = self.forward(inputs, stage="train", task_id=self.task_id)
        classification_loss = self.criterion(logits, targets)

        # Fix adds no extra terms, so the total loss is the classification loss
        total_loss = classification_loss

        # hard predictions and batch accuracy
        predictions = logits.argmax(dim=1)
        batch_accuracy = (predictions == targets).float().mean()

        return {
            "preds": predictions,
            "loss": total_loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": classification_loss,
            "acc": batch_accuracy,
            "activations": activations,
        }

Fix algorithm.

Another naive way for task-incremental learning aside from Finetuning. It simply fixes the backbone forever after training the first task. It serves as a kind of toy algorithm when discussing the stability-plasticity dilemma in continual learning.

We implement Fix as a subclass of Finetuning, as it shares forward(), validation_step(), and test_step() with Finetuning.

Fix( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
def __init__(
    self,
    backbone: CLBackbone,
    heads: HeadsTIL | HeadsCIL | HeadDIL,
    non_algorithmic_hparams: dict[str, Any] | None = None,
    **kwargs,
) -> None:
    r"""Initialize the Fix algorithm with the network. It has no additional hyperparameters.

    **Args:**
    - **backbone** (`CLBackbone`): backbone network.
    - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
    - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. `None` (the default) means an empty dict; `None` is the sentinel to avoid a shared mutable default argument.
    - **kwargs**: Reserved for multiple inheritance.
    """
    super().__init__(
        backbone=backbone,
        heads=heads,
        # resolve the None sentinel so the parent still receives a dict
        non_algorithmic_hparams=(
            non_algorithmic_hparams if non_algorithmic_hparams is not None else {}
        ),
        **kwargs,
    )

    # freeze only once after task 1
    self._backbone_frozen: bool = False

Initialize the Fix algorithm with the network. It has no additional hyperparameters.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL | HeadDIL): output heads.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
def training_step(self, batch: Any) -> dict[str, Tensor]:
    """Training step for current task `self.task_id`.

    **Args:**
    - **batch** (`Any`): a batch of training data.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary that contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
    """
    x, y = batch

    if self.task_id != 1:
        if not self._backbone_frozen:
            # freeze parameters only once after the first task
            for p in self.backbone.parameters():
                p.requires_grad = False
            self._backbone_frozen = True
            pylogger.info("Fix: backbone frozen after task 1 (set to eval mode).")
        # re-apply eval mode on EVERY step: Lightning switches the whole module
        # back to train mode at each training-epoch start, which would otherwise
        # re-enable BatchNorm running-stat updates and Dropout despite the freeze.
        self.backbone.eval()
    else:
        # ensure backbone is trainable during the first task
        self.backbone.train()

    # classification loss
    logits, activations = self.forward(x, stage="train", task_id=self.task_id)
    loss_cls = self.criterion(logits, y)

    # total loss (Fix adds no regularization terms)
    loss = loss_cls

    # predicted labels
    preds = logits.argmax(dim=1)

    # accuracy of the batch
    acc = (preds == y).float().mean()

    return {
        "preds": preds,
        "loss": loss,  # return loss is essential for training step, or backpropagation will fail
        "loss_cls": loss_cls,
        "acc": acc,
        "activations": activations,
    }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary that contains the loss and other metrics from this training step. Keys (str) are the metric names, and values (Tensor) are the metrics. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.