clarena.cl_algorithms.finetuning
The submodule in `cl_algorithms` for the Finetuning algorithm.
```python
"""
The submodule in `cl_algorithms` for the Finetuning algorithm.
"""

__all__ = ["Finetuning"]

import logging
from typing import Any

import torch
from torch.utils.data import DataLoader

from clarena.cl_algorithms import CLAlgorithm

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class Finetuning(CLAlgorithm):
    """Finetuning algorithm.

    It is the most naive way for task-incremental learning. It simply initialises the backbone from the last task when training a new task.
    """

    def __init__(
        self,
        backbone: torch.nn.Module,
        heads: torch.nn.Module,
    ) -> None:
        """Initialise the Finetuning algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        """
        super().__init__(backbone=backbone, heads=heads)

    def training_step(self, batch: Any):
        """Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.
        """
        x, y = batch
        logits = self.forward(x, self.task_id)
        loss_cls = self.criterion(logits, y)
        loss = loss_cls
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Returning the loss is essential for the training step, or backpropagation will fail
        # Return other metrics for the lightning loggers callback to handle at `on_train_batch_end()`
        return {
            "loss": loss,
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def validation_step(self, batch: Any):
        """Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.
        """
        x, y = batch
        logits = self.forward(x, self.task_id)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Return metrics for the lightning loggers callback to handle at `on_validation_batch_end()`
        return {
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def test_step(self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0):
        """Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.
        """
        test_task_id = dataloader_idx + 1

        x, y = batch
        logits = self.forward(
            x, test_task_id
        )  # use the corresponding head to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Return metrics for the lightning loggers callback to handle at `on_test_batch_end()`
        return {
            "loss_cls": loss_cls,
            "acc": acc,
        }
```
class Finetuning(clarena.cl_algorithms.base.CLAlgorithm):
Finetuning algorithm.

It is the most naive approach to task-incremental learning: when training a new task, it simply initialises the backbone from the one trained on the last task, with no mechanism to prevent forgetting.
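The pattern can be sketched without clarena at all. Below is a minimal, framework-free illustration of naive finetuning over a task sequence, not the clarena implementation; `task_loaders` (a dict of per-task training `DataLoader`s) and the layer sizes are assumptions for illustration.

```python
# Minimal sketch of naive finetuning over a task sequence (plain PyTorch, not
# the clarena implementation). `task_loaders` is a hypothetical dict mapping
# task IDs to training DataLoaders; layer sizes are placeholders.
import torch
from torch import nn

backbone = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU())
heads = {task_id: nn.Linear(128, 10) for task_id in task_loaders}  # one head per task (TIL)

for task_id, loader in task_loaders.items():
    params = list(backbone.parameters()) + list(heads[task_id].parameters())
    optimizer = torch.optim.SGD(params, lr=0.01)
    for x, y in loader:
        logits = heads[task_id](backbone(x))            # only the current task's head
        loss = nn.functional.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # no copying, freezing or regularisation: the next task simply continues
    # from whatever backbone weights this task leaves behind
```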
Finetuning(backbone: torch.nn.modules.module.Module, heads: torch.nn.modules.module.Module)
Initialise the Finetuning algorithm with the network. It has no additional hyperparameters.

**Args:**
- **backbone** (`CLBackbone`): backbone network.
- **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
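A usage sketch is given below. The `backbone` and `heads` placeholders stand for a `CLBackbone` and a `HeadsTIL`/`HeadsCIL` instance built elsewhere (e.g. by the experiment configuration); only the `Finetuning` constructor call itself comes from this module.

```python
# Hedged usage sketch: only the Finetuning(...) call is taken from this module.
# `backbone` and `heads` must be constructed elsewhere in clarena.
from clarena.cl_algorithms.finetuning import Finetuning

backbone = ...  # a CLBackbone instance
heads = ...     # a HeadsTIL (TIL) or HeadsCIL (CIL) instance

algorithm = Finetuning(backbone=backbone, heads=heads)  # no extra hyperparameters
```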
def training_step(self, batch: Any):
Training step for current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of training data.
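The dictionary returned by `training_step()` is handed to callbacks; Lightning uses only the `"loss"` entry for backpropagation, while the other entries exist for logging. A minimal sketch of a callback reading it (the Lightning 2.x import path is an assumption; clarena's own logger callbacks may differ):

```python
# Hedged sketch: consuming the dict returned by training_step() in a callback.
# The import path assumes a Lightning 2.x layout.
from lightning.pytorch.callbacks import Callback

class PrintTrainMetrics(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `outputs` is the {"loss", "loss_cls", "acc"} dict from training_step()
        print(f"task {pl_module.task_id}: loss_cls={outputs['loss_cls']:.4f} acc={outputs['acc']:.4f}")
```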
def validation_step(self, batch: Any):
Validation step for current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of validation data.
def test_step(self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0):
Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.

**Args:**
- **batch** (`Any`): a batch of test data.
- **dataloader_idx** (`int`): the dataloader index of the seen task to be tested; the corresponding task ID is `dataloader_idx + 1`. A default value of 0 is given; otherwise the LightningModule will raise a `RuntimeError`.
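In other words, when Lightning tests with one dataloader per seen task, `dataloader_idx` is 0-based while clarena task IDs start at 1, hence the `test_task_id = dataloader_idx + 1` conversion inside `test_step()`. A trivial illustration:

```python
# Illustration of the index-to-task-ID mapping used in test_step(): Lightning
# numbers per-task test dataloaders from 0, clarena task IDs start at 1.
num_seen_tasks = 3  # hypothetical: tasks 1, 2 and 3 have been trained so far
for dataloader_idx in range(num_seen_tasks):
    test_task_id = dataloader_idx + 1  # same conversion as in test_step()
    print(f"dataloader_idx={dataloader_idx} -> evaluate with the head of task {test_task_id}")
```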