clarena.cl_algorithms.finetuning

The submodule in cl_algorithms for the Finetuning algorithm.

  1r"""
  2The submodule in `cl_algorithms` for Finetuning algorithm.
  3"""
  4
  5__all__ = ["Finetuning"]
  6
  7import logging
  8from typing import Any
  9
 10import torch
 11from torch import Tensor
 12from torch.utils.data import DataLoader
 13
 14from clarena.backbones import CLBackbone
 15from clarena.cl_algorithms import CLAlgorithm
 16from clarena.cl_heads import HeadsCIL, HeadsTIL
 17
 18# always get logger for built-in logging in each module
 19pylogger = logging.getLogger(__name__)
 20
 21
 22class Finetuning(CLAlgorithm):
 23    r"""Finetuning algorithm.
 24
 25    It is the most naive way for task-incremental learning. It simply initialises the backbone from the last task when training new task.
 26    """
 27
 28    def __init__(
 29        self,
 30        backbone: CLBackbone,
 31        heads: HeadsTIL | HeadsCIL,
 32    ) -> None:
 33        r"""Initialise the Finetuning algorithm with the network. It has no additional hyperparamaters.
 34
 35        **Args:**
 36        - **backbone** (`CLBackbone`): backbone network.
 37        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
 38        """
 39        CLAlgorithm.__init__(self, backbone=backbone, heads=heads)
 40
 41    def forward(self, input: Tensor, stage: str, task_id: int | None = None) -> Tensor:
 42        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
 43
 44        **Args:**
 45        - **input** (`Tensor`): The input tensor from data.
 46        - **stage** (`str`): the stage of the forward pass, should be one of the following:
 47            1. 'train': training stage.
 48            2. 'validation': validation stage.
 49            3. 'test': testing stage.
 50        - **task_id** (`int`): the task ID where the data are from. If stage is 'train' or `validation`, it is usually from the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistence but never used, and best practices are not to provide this argument and leave it as the default value. Finetuning algorithm works both for TIL and CIL.
 51
 52        **Returns:**
 53        - **logits** (`Tensor`): the output logits tensor.
 54        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although Finetuning algorithm does not need this, it is still provided for API consistence for other algorithms inherited this `forward()` method of `Finetuning` class.
 55        """
 56        feature, hidden_features = self.backbone(input, stage=stage, task_id=task_id)
 57        logits = self.heads(feature, task_id)
 58        return logits, hidden_features
 59
 60    def training_step(self, batch: Any) -> dict[str, Tensor]:
 61        """Training step for current task `self.task_id`.
 62
 63        **Args:**
 64        - **batch** (`Any`): a batch of training data.
 65
 66        **Returns:**
 67        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
 68        """
 69        x, y = batch
 70
 71        # classification loss
 72        logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id)
 73        loss_cls = self.criterion(logits, y)
 74
 75        # total loss
 76        loss = loss_cls
 77
 78        # accuracy of the batch
 79        acc = (logits.argmax(dim=1) == y).float().mean()
 80
 81        return {
 82            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
 83            "loss_cls": loss_cls,
 84            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
 85            "hidden_features": hidden_features,
 86        }
 87
 88    def validation_step(self, batch: Any) -> dict[str, Tensor]:
 89        r"""Validation step for current task `self.task_id`.
 90
 91        **Args:**
 92        - **batch** (`Any`): a batch of validation data.
 93
 94        **Returns:**
 95        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
 96        """
 97        x, y = batch
 98        logits, hidden_features = self.forward(
 99            x,
100            stage="validation",
101            task_id=self.task_id,
102        )
103        loss_cls = self.criterion(logits, y)
104        acc = (logits.argmax(dim=1) == y).float().mean()
105
106        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
107        return {
108            "loss_cls": loss_cls,
109            "acc": acc,
110        }
111
112    def test_step(
113        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
114    ) -> dict[str, Tensor]:
115        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
116
117        **Args:**
118        - **batch** (`Any`): a batch of test data.
119        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
120
121        **Returns:**
122        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
123        """
124        test_task_id = dataloader_idx + 1
125
126        x, y = batch
127        logits, hidden_features = self.forward(
128            x, stage="test", task_id=test_task_id
129        )  # use the corresponding head to test (instead of the current task `self.task_id`)
130        loss_cls = self.criterion(logits, y)
131        acc = (logits.argmax(dim=1) == y).float().mean()
132
133        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
134        return {
135            "loss_cls": loss_cls,
136            "acc": acc,
137        }
class Finetuning(clarena.cl_algorithms.base.CLAlgorithm):
 23class Finetuning(CLAlgorithm):
 24    r"""Finetuning algorithm.
 25
 26    It is the most naive way for task-incremental learning. It simply initialises the backbone from the last task when training new task.
 27    """
 28
 29    def __init__(
 30        self,
 31        backbone: CLBackbone,
 32        heads: HeadsTIL | HeadsCIL,
 33    ) -> None:
 34        r"""Initialise the Finetuning algorithm with the network. It has no additional hyperparamaters.
 35
 36        **Args:**
 37        - **backbone** (`CLBackbone`): backbone network.
 38        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
 39        """
 40        CLAlgorithm.__init__(self, backbone=backbone, heads=heads)
 41
 42    def forward(self, input: Tensor, stage: str, task_id: int | None = None) -> Tensor:
 43        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
 44
 45        **Args:**
 46        - **input** (`Tensor`): The input tensor from data.
 47        - **stage** (`str`): the stage of the forward pass, should be one of the following:
 48            1. 'train': training stage.
 49            2. 'validation': validation stage.
 50            3. 'test': testing stage.
 51        - **task_id** (`int`): the task ID where the data are from. If stage is 'train' or `validation`, it is usually from the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistence but never used, and best practices are not to provide this argument and leave it as the default value. Finetuning algorithm works both for TIL and CIL.
 52
 53        **Returns:**
 54        - **logits** (`Tensor`): the output logits tensor.
 55        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although Finetuning algorithm does not need this, it is still provided for API consistence for other algorithms inherited this `forward()` method of `Finetuning` class.
 56        """
 57        feature, hidden_features = self.backbone(input, stage=stage, task_id=task_id)
 58        logits = self.heads(feature, task_id)
 59        return logits, hidden_features
 60
 61    def training_step(self, batch: Any) -> dict[str, Tensor]:
 62        """Training step for current task `self.task_id`.
 63
 64        **Args:**
 65        - **batch** (`Any`): a batch of training data.
 66
 67        **Returns:**
 68        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
 69        """
 70        x, y = batch
 71
 72        # classification loss
 73        logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id)
 74        loss_cls = self.criterion(logits, y)
 75
 76        # total loss
 77        loss = loss_cls
 78
 79        # accuracy of the batch
 80        acc = (logits.argmax(dim=1) == y).float().mean()
 81
 82        return {
 83            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
 84            "loss_cls": loss_cls,
 85            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
 86            "hidden_features": hidden_features,
 87        }
 88
 89    def validation_step(self, batch: Any) -> dict[str, Tensor]:
 90        r"""Validation step for current task `self.task_id`.
 91
 92        **Args:**
 93        - **batch** (`Any`): a batch of validation data.
 94
 95        **Returns:**
 96        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
 97        """
 98        x, y = batch
 99        logits, hidden_features = self.forward(
100            x,
101            stage="validation",
102            task_id=self.task_id,
103        )
104        loss_cls = self.criterion(logits, y)
105        acc = (logits.argmax(dim=1) == y).float().mean()
106
107        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
108        return {
109            "loss_cls": loss_cls,
110            "acc": acc,
111        }
112
113    def test_step(
114        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
115    ) -> dict[str, Tensor]:
116        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
117
118        **Args:**
119        - **batch** (`Any`): a batch of test data.
120        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
121
122        **Returns:**
123        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
124        """
125        test_task_id = dataloader_idx + 1
126
127        x, y = batch
128        logits, hidden_features = self.forward(
129            x, stage="test", task_id=test_task_id
130        )  # use the corresponding head to test (instead of the current task `self.task_id`)
131        loss_cls = self.criterion(logits, y)
132        acc = (logits.argmax(dim=1) == y).float().mean()
133
134        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
135        return {
136            "loss_cls": loss_cls,
137            "acc": acc,
138        }

Finetuning algorithm.

It is the most naive approach to task-incremental learning. It simply initialises the backbone from the previous task when training a new task.

def __init__(
    self,
    backbone: CLBackbone,
    heads: HeadsTIL | HeadsCIL,
) -> None:
    r"""Initialise the Finetuning algorithm with the network. It has no additional hyperparameters.

    **Args:**
    - **backbone** (`CLBackbone`): backbone network.
    - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
    """
    # Finetuning keeps no state of its own; everything is set up by the base class.
    CLAlgorithm.__init__(self, backbone=backbone, heads=heads)

Initialise the Finetuning algorithm with the network. It has no additional hyperparameters.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL): output heads.
def forward( self, input: torch.Tensor, stage: str, task_id: int | None = None) -> torch.Tensor:
42    def forward(self, input: Tensor, stage: str, task_id: int | None = None) -> Tensor:
43        r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`.
44
45        **Args:**
46        - **input** (`Tensor`): The input tensor from data.
47        - **stage** (`str`): the stage of the forward pass, should be one of the following:
48            1. 'train': training stage.
49            2. 'validation': validation stage.
50            3. 'test': testing stage.
51        - **task_id** (`int`): the task ID where the data are from. If stage is 'train' or `validation`, it is usually from the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistence but never used, and best practices are not to provide this argument and leave it as the default value. Finetuning algorithm works both for TIL and CIL.
52
53        **Returns:**
54        - **logits** (`Tensor`): the output logits tensor.
55        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although Finetuning algorithm does not need this, it is still provided for API consistence for other algorithms inherited this `forward()` method of `Finetuning` class.
56        """
57        feature, hidden_features = self.backbone(input, stage=stage, task_id=task_id)
58        logits = self.heads(feature, task_id)
59        return logits, hidden_features

The forward pass for data from task task_id. Note that it has nothing to do with the forward() method in nn.Module.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): the stage of the forward pass, should be one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
  • task_id (int): the task ID where the data are from. If stage is 'train' or 'validation', it is usually the current task self.task_id. If stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and never used; best practice is not to provide this argument and to leave it as the default value. The Finetuning algorithm works for both TIL and CIL.

Returns:

  • logits (Tensor): the output logits tensor.
  • hidden_features (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Keys (str) are the weighted layer names and values (Tensor) are the hidden feature tensors. They are used by continual learning algorithms that need the hidden features for various purposes. Although the Finetuning algorithm does not need them, they are still provided for API consistency with other algorithms that inherit this forward() method of the Finetuning class.
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
def training_step(self, batch: Any) -> dict[str, Any]:
    r"""Training step for the current task `self.task_id`.

    **Args:**
    - **batch** (`Any`): a batch of training data, unpacked as `(x, y)` (inputs and labels).

    **Returns:**
    - **outputs** (`dict[str, Any]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are the metric names. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. It also contains 'loss_cls' and 'acc' (`Tensor`), and 'hidden_features' (`dict[str, Tensor]`) as returned by `forward()`.
    """
    x, y = batch

    # classification loss
    logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id)
    loss_cls = self.criterion(logits, y)

    # total loss: Finetuning adds nothing on top of the classification loss
    loss = loss_cls

    # accuracy of the batch: fraction of samples whose top-scoring class equals the label
    acc = (logits.argmax(dim=1) == y).float().mean()

    return {
        "loss": loss,  # return loss is essential for training step, or backpropagation will fail
        "loss_cls": loss_cls,
        "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
        "hidden_features": hidden_features,
    }

Training step for current task self.task_id.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and other metrics from this training step. Keys (str) are the metric names and values (Tensor) are the metric values. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
 89    def validation_step(self, batch: Any) -> dict[str, Tensor]:
 90        r"""Validation step for current task `self.task_id`.
 91
 92        **Args:**
 93        - **batch** (`Any`): a batch of validation data.
 94
 95        **Returns:**
 96        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
 97        """
 98        x, y = batch
 99        logits, hidden_features = self.forward(
100            x,
101            stage="validation",
102            task_id=self.task_id,
103        )
104        loss_cls = self.criterion(logits, y)
105        acc = (logits.argmax(dim=1) == y).float().mean()
106
107        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
108        return {
109            "loss_cls": loss_cls,
110            "acc": acc,
111        }

Validation step for current task self.task_id.

Args:

  • batch (Any): a batch of validation data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and other metrics from this validation step. Keys (str) are the metric names and values (Tensor) are the metric values.
def test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
113    def test_step(
114        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
115    ) -> dict[str, Tensor]:
116        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
117
118        **Args:**
119        - **batch** (`Any`): a batch of test data.
120        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
121
122        **Returns:**
123        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
124        """
125        test_task_id = dataloader_idx + 1
126
127        x, y = batch
128        logits, hidden_features = self.forward(
129            x, stage="test", task_id=test_task_id
130        )  # use the corresponding head to test (instead of the current task `self.task_id`)
131        loss_cls = self.criterion(logits, y)
132        acc = (logits.argmax(dim=1) == y).float().mean()
133
134        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
135        return {
136            "loss_cls": loss_cls,
137            "acc": acc,
138        }

Test step for the current task self.task_id, which tests all seen tasks indexed by dataloader_idx.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the index of the seen task to be tested; the corresponding task ID is dataloader_idx + 1. A default value of 0 is given, otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and other metrics from this test step. Keys (str) are the metric names and values (Tensor) are the metric values.