clarena.cl_algorithms.finetuning

The submodule in `cl_algorithms` for the Finetuning algorithm.

 1"""
 2The submodule in `cl_algorithms` for Finetuning algorithm.
 3"""
 4
 5__all__ = ["Finetuning"]
 6
 7import logging
 8from typing import Any
 9
10import torch
11from torch.utils.data import DataLoader
12
13from clarena.cl_algorithms import CLAlgorithm
14
# Module-level logger for built-in logging, named after this module.
pylogger = logging.getLogger(name=__name__)
17
18
class Finetuning(CLAlgorithm):
    """Finetuning algorithm.

    It is the most naive approach to task-incremental learning: when training a new task, it simply initialises
    the backbone from the one trained on the previous task. Only the classification loss is optimised; no
    regularisation term is added.
    """

    def __init__(
        self,
        backbone: torch.nn.Module,
        heads: torch.nn.Module,
    ) -> None:
        """Initialise the Finetuning algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        """
        super().__init__(backbone=backbone, heads=heads)

    def training_step(self, batch: Any):
        """Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data, unpacked as `(x, y)` (inputs and targets).

        **Returns:**
        - **outputs** (`dict`): `"loss"` (the total loss to backpropagate), `"loss_cls"` (the classification loss)
        and `"acc"` (the batch accuracy).
        """
        x, y = batch
        logits = self.forward(x, self.task_id)
        loss_cls = self.criterion(logits, y)
        # Finetuning adds no regularisation, so the total loss is just the classification loss
        loss = loss_cls
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Return loss is essential for training step, or backpropagation will fail
        # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
        return {
            "loss": loss,
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def validation_step(self, batch: Any):
        """Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data, unpacked as `(x, y)` (inputs and targets).

        **Returns:**
        - **outputs** (`dict`): `"loss_cls"` (the classification loss) and `"acc"` (the batch accuracy).
        """
        x, y = batch
        logits = self.forward(x, self.task_id)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
        return {
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data, unpacked as `(x, y)` (inputs and targets).
        - **batch_idx** (`int`): the index of the batch within its dataloader. Not used by this method.
        - **dataloader_idx** (`int`): the index of the test dataloader. Each seen task has its own test dataloader,
        and the task tested here is `dataloader_idx + 1` (so the first dataloader corresponds to task 1). A default
        value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict`): `"loss_cls"` (the classification loss) and `"acc"` (the batch accuracy) on the
        tested task.
        """
        # map the dataloader index to the ID of the task being tested
        test_task_id = dataloader_idx + 1

        x, y = batch
        # use the head of the task being tested (instead of the current task `self.task_id`)
        logits = self.forward(x, test_task_id)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
        return {
            "loss_cls": loss_cls,
            "acc": acc,
        }
class Finetuning(clarena.cl_algorithms.base.CLAlgorithm):
class Finetuning(CLAlgorithm):
    """Finetuning algorithm.

    It is the most naive approach to task-incremental learning: when training a new task, it simply initialises
    the backbone from the one trained on the previous task. Only the classification loss is optimised; no
    regularisation term is added.
    """

    def __init__(
        self,
        backbone: torch.nn.Module,
        heads: torch.nn.Module,
    ) -> None:
        """Initialise the Finetuning algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        """
        super().__init__(backbone=backbone, heads=heads)

    def training_step(self, batch: Any):
        """Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data, unpacked as `(x, y)` (inputs and targets).

        **Returns:**
        - **outputs** (`dict`): `"loss"` (the total loss to backpropagate), `"loss_cls"` (the classification loss)
        and `"acc"` (the batch accuracy).
        """
        x, y = batch
        logits = self.forward(x, self.task_id)
        loss_cls = self.criterion(logits, y)
        # Finetuning adds no regularisation, so the total loss is just the classification loss
        loss = loss_cls
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Return loss is essential for training step, or backpropagation will fail
        # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
        return {
            "loss": loss,
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def validation_step(self, batch: Any):
        """Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data, unpacked as `(x, y)` (inputs and targets).

        **Returns:**
        - **outputs** (`dict`): `"loss_cls"` (the classification loss) and `"acc"` (the batch accuracy).
        """
        x, y = batch
        logits = self.forward(x, self.task_id)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
        return {
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data, unpacked as `(x, y)` (inputs and targets).
        - **batch_idx** (`int`): the index of the batch within its dataloader. Not used by this method.
        - **dataloader_idx** (`int`): the index of the test dataloader. Each seen task has its own test dataloader,
        and the task tested here is `dataloader_idx + 1` (so the first dataloader corresponds to task 1). A default
        value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict`): `"loss_cls"` (the classification loss) and `"acc"` (the batch accuracy) on the
        tested task.
        """
        # map the dataloader index to the ID of the task being tested
        test_task_id = dataloader_idx + 1

        x, y = batch
        # use the head of the task being tested (instead of the current task `self.task_id`)
        logits = self.forward(x, test_task_id)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
        return {
            "loss_cls": loss_cls,
            "acc": acc,
        }

Finetuning algorithm.

It is the most naive approach to task-incremental learning: when training a new task, it simply initialises the backbone from the one trained on the previous task.

Finetuning( backbone: torch.nn.modules.module.Module, heads: torch.nn.modules.module.Module)
26    def __init__(
27        self,
28        backbone: torch.nn.Module,
29        heads: torch.nn.Module,
30    ) -> None:
31        """Initialise the Finetuning algorithm with the network. It has no additional hyperparamaters.
32
33        **Args:**
34        - **backbone** (`CLBackbone`): backbone network.
35        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
36        """
37        super().__init__(backbone=backbone, heads=heads)

Initialise the Finetuning algorithm with the network. It has no additional hyperparameters.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL): output heads.
def training_step(self, batch: Any):
39    def training_step(self, batch: Any):
40        """Training step for current task `self.task_id`.
41
42        **Args:**
43        - **batch** (`Any`): a batch of training data.
44        """
45        x, y = batch
46        logits = self.forward(x, self.task_id)
47        loss_cls = self.criterion(logits, y)
48        loss = loss_cls
49        acc = (logits.argmax(dim=1) == y).float().mean()
50
51        # Return loss is essential for training step, or backpropagation will fail
52        # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
53        return {
54            "loss": loss,
55            "loss_cls": loss_cls,
56            "acc": acc,
57        }

Training step for the current task `self.task_id`.

Args:

  • batch (Any): a batch of training data.
def validation_step(self, batch: Any):
59    def validation_step(self, batch: Any):
60        """Validation step for current task `self.task_id`.
61
62        **Args:**
63        - **batch** (`Any`): a batch of validation data.
64        """
65        x, y = batch
66        logits = self.forward(x, self.task_id)
67        loss_cls = self.criterion(logits, y)
68        acc = (logits.argmax(dim=1) == y).float().mean()
69
70        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
71        return {
72            "loss_cls": loss_cls,
73            "acc": acc,
74        }

Validation step for the current task `self.task_id`.

Args:

  • batch (Any): a batch of validation data.
def test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0):
76    def test_step(self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0):
77        """Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
78
79        **Args:**
80        - **batch** (`Any`): a batch of test data.
81        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
82        """
83        test_task_id = dataloader_idx + 1
84
85        x, y = batch
86        logits = self.forward(
87            x, test_task_id
88        )  # use the corresponding head to test (instead of the current task `self.task_id`)
89        loss_cls = self.criterion(logits, y)
90        acc = (logits.argmax(dim=1) == y).float().mean()
91
92        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
93        return {
94            "loss_cls": loss_cls,
95            "acc": acc,
96        }

Test step for the current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

Args:

  • batch (Any): a batch of test data.
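  • batch_idx (int): the index of the batch within its dataloader; not used by this method.
  • dataloader_idx (int): the index of the test dataloader of the seen task to be tested; the tested task ID is dataloader_idx + 1. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.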