clarena.cl_algorithms.finetuning

The submodule in cl_algorithms for Finetuning algorithm.

  1r"""
  2The submodule in `cl_algorithms` for Finetuning algorithm.
  3"""
  4
  5__all__ = ["Finetuning", "AmnesiacFinetuning"]
  6
  7import logging
  8from typing import Any
  9
 10from torch import Tensor
 11from torch.utils.data import DataLoader
 12
 13from clarena.backbones import CLBackbone
 14from clarena.cl_algorithms import AmnesiacCLAlgorithm, CLAlgorithm
 15from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL
 16
 17# always get logger for built-in logging in each module
 18pylogger = logging.getLogger(__name__)
 19
 20
 21class Finetuning(CLAlgorithm):
 22    r"""Finetuning algorithm.
 23
 24    The most naive way for task-incremental learning. It simply initializes the backbone from the last task when training new task.
 25    """
 26
 27    def __init__(
 28        self,
 29        backbone: CLBackbone,
 30        heads: HeadsTIL | HeadsCIL | HeadDIL,
 31        non_algorithmic_hparams: dict[str, Any] = {},
 32        **kwargs,
 33    ) -> None:
 34        r"""Initialize the Finetuning algorithm with the network. It has no additional hyperparameters.
 35
 36        **Args:**
 37        - **backbone** (`CLBackbone`): backbone network.
 38        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
 39        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 40        - **kwargs**: Reserved for multiple inheritance.
 41        """
 42        super().__init__(
 43            backbone=backbone,
 44            heads=heads,
 45            non_algorithmic_hparams=non_algorithmic_hparams,
 46            **kwargs,
 47        )
 48
 49    def training_step(self, batch: Any) -> dict[str, Tensor]:
 50        """Training step for current task `self.task_id`.
 51
 52        **Args:**
 53        - **batch** (`Any`): a batch of training data.
 54
 55        **Returns:**
 56        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
 57        """
 58        x, y = batch
 59
 60        # classification loss
 61        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
 62        loss_cls = self.criterion(logits, y)
 63
 64        # total loss
 65        loss = loss_cls
 66
 67        # predicted labels
 68        preds = logits.argmax(dim=1)
 69
 70        # accuracy of the batch
 71        acc = (preds == y).float().mean()
 72
 73        return {
 74            "preds": preds,
 75            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
 76            "loss_cls": loss_cls,
 77            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
 78            "activations": activations,
 79        }
 80
 81    def validation_step(self, batch: Any) -> dict[str, Tensor]:
 82        r"""Validation step for current task `self.task_id`.
 83
 84        **Args:**
 85        - **batch** (`Any`): a batch of validation data.
 86
 87        **Returns:**
 88        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Key (`str`) are the metrics names, value (`Tensor`) are the metrics.
 89        """
 90        x, y = batch
 91        logits, _ = self.forward(x, stage="validation", task_id=self.task_id)
 92        loss_cls = self.criterion(logits, y)
 93        preds = logits.argmax(dim=1)
 94        acc = (preds == y).float().mean()
 95
 96        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
 97        return {
 98            "preds": preds,
 99            "loss_cls": loss_cls,
100            "acc": acc,
101        }
102
103    def test_step(
104        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
105    ) -> dict[str, Tensor]:
106        r"""Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.
107
108        **Args:**
109        - **batch** (`Any`): a batch of test data.
110        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
111
112        **Returns:**
113        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Key (`str`) are the metrics name, value (`Tensor`) are the metrics.
114        """
115        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
116
117        x, y = batch
118
119        logits, _ = self.forward(
120            x, stage="test", task_id=test_task_id
121        )  # use the corresponding head to test (instead of the current task `self.task_id`)
122        loss_cls = self.criterion(logits, y)
123        preds = logits.argmax(dim=1)
124        acc = (preds == y).float().mean()
125
126        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
127        return {
128            "preds": preds,
129            "loss_cls": loss_cls,
130            "acc": acc,
131        }
132
133
class AmnesiacFinetuning(AmnesiacCLAlgorithm, Finetuning):
    r"""Amnesiac Finetuning algorithm.

    Combines the naive `Finetuning` algorithm with the amnesiac unlearning
    mechanism provided by `AmnesiacCLAlgorithm`.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""Initialize the Amnesiac Finetuning algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. `None` (the default) is treated as an empty dict; a `None` default avoids the shared mutable default argument pitfall.
        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            # normalize None to {} so the parent always receives a dict
            non_algorithmic_hparams=(
                {} if non_algorithmic_hparams is None else non_algorithmic_hparams
            ),
            disable_unlearning=disable_unlearning,
            **kwargs,
        )

    def on_train_start(self) -> None:
        """Record backbone parameters before training current task.

        Explicitly dispatches to both parent hooks so neither is skipped by the MRO.
        """
        Finetuning.on_train_start(self)
        AmnesiacCLAlgorithm.on_train_start(self)

    def on_train_end(self) -> None:
        """Record backbone parameters after training current task.

        (Docstring fixed: it previously said "before", copy-pasted from `on_train_start`.)
        Explicitly dispatches to both parent hooks so neither is skipped by the MRO.
        """
        Finetuning.on_train_end(self)
        AmnesiacCLAlgorithm.on_train_end(self)
class Finetuning(clarena.cl_algorithms.base.CLAlgorithm):
 22class Finetuning(CLAlgorithm):
 23    r"""Finetuning algorithm.
 24
 25    The most naive way for task-incremental learning. It simply initializes the backbone from the last task when training new task.
 26    """
 27
 28    def __init__(
 29        self,
 30        backbone: CLBackbone,
 31        heads: HeadsTIL | HeadsCIL | HeadDIL,
 32        non_algorithmic_hparams: dict[str, Any] = {},
 33        **kwargs,
 34    ) -> None:
 35        r"""Initialize the Finetuning algorithm with the network. It has no additional hyperparameters.
 36
 37        **Args:**
 38        - **backbone** (`CLBackbone`): backbone network.
 39        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
 40        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 41        - **kwargs**: Reserved for multiple inheritance.
 42        """
 43        super().__init__(
 44            backbone=backbone,
 45            heads=heads,
 46            non_algorithmic_hparams=non_algorithmic_hparams,
 47            **kwargs,
 48        )
 49
 50    def training_step(self, batch: Any) -> dict[str, Tensor]:
 51        """Training step for current task `self.task_id`.
 52
 53        **Args:**
 54        - **batch** (`Any`): a batch of training data.
 55
 56        **Returns:**
 57        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
 58        """
 59        x, y = batch
 60
 61        # classification loss
 62        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
 63        loss_cls = self.criterion(logits, y)
 64
 65        # total loss
 66        loss = loss_cls
 67
 68        # predicted labels
 69        preds = logits.argmax(dim=1)
 70
 71        # accuracy of the batch
 72        acc = (preds == y).float().mean()
 73
 74        return {
 75            "preds": preds,
 76            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
 77            "loss_cls": loss_cls,
 78            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
 79            "activations": activations,
 80        }
 81
 82    def validation_step(self, batch: Any) -> dict[str, Tensor]:
 83        r"""Validation step for current task `self.task_id`.
 84
 85        **Args:**
 86        - **batch** (`Any`): a batch of validation data.
 87
 88        **Returns:**
 89        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Key (`str`) are the metrics names, value (`Tensor`) are the metrics.
 90        """
 91        x, y = batch
 92        logits, _ = self.forward(x, stage="validation", task_id=self.task_id)
 93        loss_cls = self.criterion(logits, y)
 94        preds = logits.argmax(dim=1)
 95        acc = (preds == y).float().mean()
 96
 97        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
 98        return {
 99            "preds": preds,
100            "loss_cls": loss_cls,
101            "acc": acc,
102        }
103
104    def test_step(
105        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
106    ) -> dict[str, Tensor]:
107        r"""Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.
108
109        **Args:**
110        - **batch** (`Any`): a batch of test data.
111        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
112
113        **Returns:**
114        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Key (`str`) are the metrics name, value (`Tensor`) are the metrics.
115        """
116        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
117
118        x, y = batch
119
120        logits, _ = self.forward(
121            x, stage="test", task_id=test_task_id
122        )  # use the corresponding head to test (instead of the current task `self.task_id`)
123        loss_cls = self.criterion(logits, y)
124        preds = logits.argmax(dim=1)
125        acc = (preds == y).float().mean()
126
127        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
128        return {
129            "preds": preds,
130            "loss_cls": loss_cls,
131            "acc": acc,
132        }

Finetuning algorithm.

The most naive way for task-incremental learning. It simply initializes the backbone from the last task when training new task.

Finetuning( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
28    def __init__(
29        self,
30        backbone: CLBackbone,
31        heads: HeadsTIL | HeadsCIL | HeadDIL,
32        non_algorithmic_hparams: dict[str, Any] = {},
33        **kwargs,
34    ) -> None:
35        r"""Initialize the Finetuning algorithm with the network. It has no additional hyperparameters.
36
37        **Args:**
38        - **backbone** (`CLBackbone`): backbone network.
39        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
40        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
41        - **kwargs**: Reserved for multiple inheritance.
42        """
43        super().__init__(
44            backbone=backbone,
45            heads=heads,
46            non_algorithmic_hparams=non_algorithmic_hparams,
47            **kwargs,
48        )

Initialize the Finetuning algorithm with the network. It has no additional hyperparameters.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL | HeadDIL): output heads.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
50    def training_step(self, batch: Any) -> dict[str, Tensor]:
51        """Training step for current task `self.task_id`.
52
53        **Args:**
54        - **batch** (`Any`): a batch of training data.
55
56        **Returns:**
57        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
58        """
59        x, y = batch
60
61        # classification loss
62        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
63        loss_cls = self.criterion(logits, y)
64
65        # total loss
66        loss = loss_cls
67
68        # predicted labels
69        preds = logits.argmax(dim=1)
70
71        # accuracy of the batch
72        acc = (preds == y).float().mean()
73
74        return {
75            "preds": preds,
76            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
77            "loss_cls": loss_cls,
78            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
79            "activations": activations,
80        }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this training step. Keys (str) are the metrics names, and values (Tensor) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
 82    def validation_step(self, batch: Any) -> dict[str, Tensor]:
 83        r"""Validation step for current task `self.task_id`.
 84
 85        **Args:**
 86        - **batch** (`Any`): a batch of validation data.
 87
 88        **Returns:**
 89        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Key (`str`) are the metrics names, value (`Tensor`) are the metrics.
 90        """
 91        x, y = batch
 92        logits, _ = self.forward(x, stage="validation", task_id=self.task_id)
 93        loss_cls = self.criterion(logits, y)
 94        preds = logits.argmax(dim=1)
 95        acc = (preds == y).float().mean()
 96
 97        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
 98        return {
 99            "preds": preds,
100            "loss_cls": loss_cls,
101            "acc": acc,
102        }

Validation step for current task self.task_id.

Args:

  • batch (Any): a batch of validation data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this validation step. Key (str) are the metrics names, value (Tensor) are the metrics.
def test_step( self, batch: typing.Any, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
104    def test_step(
105        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
106    ) -> dict[str, Tensor]:
107        r"""Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.
108
109        **Args:**
110        - **batch** (`Any`): a batch of test data.
111        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
112
113        **Returns:**
114        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Key (`str`) are the metrics name, value (`Tensor`) are the metrics.
115        """
116        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
117
118        x, y = batch
119
120        logits, _ = self.forward(
121            x, stage="test", task_id=test_task_id
122        )  # use the corresponding head to test (instead of the current task `self.task_id`)
123        loss_cls = self.criterion(logits, y)
124        preds = logits.argmax(dim=1)
125        acc = (preds == y).float().mean()
126
127        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
128        return {
129            "preds": preds,
130            "loss_cls": loss_cls,
131            "acc": acc,
132        }

Test step for current task self.task_id, which tests all seen tasks indexed by dataloader_idx.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this test step. Key (str) are the metrics name, value (Tensor) are the metrics.
class AmnesiacFinetuning(clarena.cl_algorithms.base.AmnesiacCLAlgorithm, Finetuning):
class AmnesiacFinetuning(AmnesiacCLAlgorithm, Finetuning):
    r"""Amnesiac Finetuning algorithm.

    Combines the naive `Finetuning` algorithm with the amnesiac unlearning
    mechanism provided by `AmnesiacCLAlgorithm`.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""Initialize the Amnesiac Finetuning algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. `None` (the default) is treated as an empty dict; a `None` default avoids the shared mutable default argument pitfall.
        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            # normalize None to {} so the parent always receives a dict
            non_algorithmic_hparams=(
                {} if non_algorithmic_hparams is None else non_algorithmic_hparams
            ),
            disable_unlearning=disable_unlearning,
            **kwargs,
        )

    def on_train_start(self) -> None:
        """Record backbone parameters before training current task.

        Explicitly dispatches to both parent hooks so neither is skipped by the MRO.
        """
        Finetuning.on_train_start(self)
        AmnesiacCLAlgorithm.on_train_start(self)

    def on_train_end(self) -> None:
        """Record backbone parameters after training current task.

        (Docstring fixed: it previously said "before", copy-pasted from `on_train_start`.)
        Explicitly dispatches to both parent hooks so neither is skipped by the MRO.
        """
        Finetuning.on_train_end(self)
        AmnesiacCLAlgorithm.on_train_end(self)

Amnesiac Finetuning algorithm.

AmnesiacFinetuning( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any] = {}, disable_unlearning: bool = False, **kwargs)
138    def __init__(
139        self,
140        backbone: CLBackbone,
141        heads: HeadsTIL | HeadsCIL | HeadDIL,
142        non_algorithmic_hparams: dict[str, Any] = {},
143        disable_unlearning: bool = False,
144        **kwargs,
145    ) -> None:
146        r"""Initialize the Amnesiac Finetuning algorithm with the network.
147
148        **Args:**
149        - **backbone** (`CLBackbone`): backbone network.
150        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
151        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
152        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
153        - **kwargs**: Reserved for multiple inheritance.
154        """
155        super().__init__(
156            backbone=backbone,
157            heads=heads,
158            non_algorithmic_hparams=non_algorithmic_hparams,
159            disable_unlearning=disable_unlearning,
160            **kwargs,
161        )

Initialize the Amnesiac Finetuning algorithm with the network.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL | HeadDIL): output heads.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • disable_unlearning (bool): whether to disable the unlearning functionality. Default is False.
  • kwargs: Reserved for multiple inheritance.
def on_train_start(self) -> None:
163    def on_train_start(self) -> None:
164        """Record backbone parameters before training current task."""
165        Finetuning.on_train_start(self)
166        AmnesiacCLAlgorithm.on_train_start(self)

Record backbone parameters before training current task.

def on_train_end(self) -> None:
168    def on_train_end(self) -> None:
169        """Record backbone parameters after training current task."""
170        Finetuning.on_train_end(self)
171        AmnesiacCLAlgorithm.on_train_end(self)

Record backbone parameters after training current task.