clarena.cl_algorithms.finetuning
The submodule in `cl_algorithms` for the Finetuning algorithm.
1r""" 2The submodule in `cl_algorithms` for Finetuning algorithm. 3""" 4 5__all__ = ["Finetuning"] 6 7import logging 8from typing import Any 9 10import torch 11from torch import Tensor 12from torch.utils.data import DataLoader 13 14from clarena.backbones import CLBackbone 15from clarena.cl_algorithms import CLAlgorithm 16from clarena.cl_heads import HeadsCIL, HeadsTIL 17 18# always get logger for built-in logging in each module 19pylogger = logging.getLogger(__name__) 20 21 22class Finetuning(CLAlgorithm): 23 r"""Finetuning algorithm. 24 25 It is the most naive way for task-incremental learning. It simply initialises the backbone from the last task when training new task. 26 """ 27 28 def __init__( 29 self, 30 backbone: CLBackbone, 31 heads: HeadsTIL | HeadsCIL, 32 ) -> None: 33 r"""Initialise the Finetuning algorithm with the network. It has no additional hyperparamaters. 34 35 **Args:** 36 - **backbone** (`CLBackbone`): backbone network. 37 - **heads** (`HeadsTIL` | `HeadsCIL`): output heads. 38 """ 39 CLAlgorithm.__init__(self, backbone=backbone, heads=heads) 40 41 def forward(self, input: Tensor, stage: str, task_id: int | None = None) -> Tensor: 42 r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`. 43 44 **Args:** 45 - **input** (`Tensor`): The input tensor from data. 46 - **stage** (`str`): the stage of the forward pass, should be one of the following: 47 1. 'train': training stage. 48 2. 'validation': validation stage. 49 3. 'test': testing stage. 50 - **task_id** (`int`): the task ID where the data are from. If stage is 'train' or `validation`, it is usually from the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistence but never used, and best practices are not to provide this argument and leave it as the default value. Finetuning algorithm works both for TIL and CIL. 51 52 **Returns:** 53 - **logits** (`Tensor`): the output logits tensor. 54 - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although Finetuning algorithm does not need this, it is still provided for API consistence for other algorithms inherited this `forward()` method of `Finetuning` class. 55 """ 56 feature, hidden_features = self.backbone(input, stage=stage, task_id=task_id) 57 logits = self.heads(feature, task_id) 58 return logits, hidden_features 59 60 def training_step(self, batch: Any) -> dict[str, Tensor]: 61 """Training step for current task `self.task_id`. 62 63 **Args:** 64 - **batch** (`Any`): a batch of training data. 65 66 **Returns:** 67 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 
68 """ 69 x, y = batch 70 71 # classification loss 72 logits, hidden_features = self.forward(x, stage="train", task_id=self.task_id) 73 loss_cls = self.criterion(logits, y) 74 75 # total loss 76 loss = loss_cls 77 78 # accuracy of the batch 79 acc = (logits.argmax(dim=1) == y).float().mean() 80 81 return { 82 "loss": loss, # return loss is essential for training step, or backpropagation will fail 83 "loss_cls": loss_cls, 84 "acc": acc, # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()` 85 "hidden_features": hidden_features, 86 } 87 88 def validation_step(self, batch: Any) -> dict[str, Tensor]: 89 r"""Validation step for current task `self.task_id`. 90 91 **Args:** 92 - **batch** (`Any`): a batch of validation data. 93 94 **Returns:** 95 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 96 """ 97 x, y = batch 98 logits, hidden_features = self.forward( 99 x, 100 stage="validation", 101 task_id=self.task_id, 102 ) 103 loss_cls = self.criterion(logits, y) 104 acc = (logits.argmax(dim=1) == y).float().mean() 105 106 # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()` 107 return { 108 "loss_cls": loss_cls, 109 "acc": acc, 110 } 111 112 def test_step( 113 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 114 ) -> dict[str, Tensor]: 115 r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`. 116 117 **Args:** 118 - **batch** (`Any`): a batch of test data. 119 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 120 121 **Returns:** 122 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 123 """ 124 test_task_id = dataloader_idx + 1 125 126 x, y = batch 127 logits, hidden_features = self.forward( 128 x, stage="test", task_id=test_task_id 129 ) # use the corresponding head to test (instead of the current task `self.task_id`) 130 loss_cls = self.criterion(logits, y) 131 acc = (logits.argmax(dim=1) == y).float().mean() 132 133 # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 134 return { 135 "loss_cls": loss_cls, 136 "acc": acc, 137 }
class Finetuning(clarena.cl_algorithms.base.CLAlgorithm):
Finetuning algorithm.
It is the most naive way of task-incremental learning: it simply initialises the backbone from the last task when training a new task.
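To make the idea concrete, here is a minimal plain-PyTorch sketch of naive sequential finetuning (illustrative only, not clarena's training loop): the backbone keeps the weights learned on the previous task as the starting point for the next one, with a separate head per task as in TIL. The layer sizes and `task_loaders` are assumptions.

```python
import torch
from torch import nn

# Illustrative sketch of the finetuning idea, not clarena's API.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU())  # shared across tasks
heads = [nn.Linear(128, 10) for _ in range(3)]  # one head per task (TIL setting)

for task_id, task_loader in enumerate(task_loaders, start=1):  # `task_loaders` assumed given
    head = heads[task_id - 1]
    # the backbone simply carries over its weights from the previous task
    optimizer = torch.optim.SGD(list(backbone.parameters()) + list(head.parameters()), lr=0.01)
    for x, y in task_loader:
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(head(backbone(x)), y)
        loss.backward()
        optimizer.step()
```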
Finetuning(backbone: clarena.backbones.CLBackbone, heads: clarena.cl_heads.HeadsTIL | clarena.cl_heads.HeadsCIL)
Initialise the Finetuning algorithm with the network. It has no additional hyperparameters.

**Args:**
- **backbone** (`CLBackbone`): backbone network.
- **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
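A minimal usage sketch of wiring the algorithm together. The backbone and heads constructors are elided because their arguments are not documented on this page; the names are illustrative only.

```python
from clarena.cl_algorithms import Finetuning

backbone = ...  # a CLBackbone instance, e.g. one of the backbones in clarena.backbones
heads = ...     # a HeadsTIL or HeadsCIL instance from clarena.cl_heads
algo = Finetuning(backbone=backbone, heads=heads)  # no extra hyperparameters
```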
def forward(self, input: torch.Tensor, stage: str, task_id: int | None = None) -> torch.Tensor:
The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.

**Args:**
- **input** (`Tensor`): the input tensor from data.
- **stage** (`str`): the stage of the forward pass, should be one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **task_id** (`int`): the task ID where the data are from. If stage is 'train' or 'validation', it is usually from the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and never used; best practice is not to provide this argument and leave it at its default value. The Finetuning algorithm works for both TIL and CIL.

**Returns:**
- **logits** (`Tensor`): the output logits tensor.
- **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the Finetuning algorithm does not need this, it is still provided for API consistency with other algorithms that inherit this `forward()` method of the `Finetuning` class.
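A hedged sketch of calling `forward()` during training on the current task, assuming the `algo` instance from the earlier sketch has been set up so that `algo.task_id` is defined; the input shape is an assumption for illustration.

```python
import torch

x = torch.randn(32, 1, 28, 28)  # assumed batch of 32 single-channel 28x28 images
logits, hidden_features = algo.forward(x, stage="train", task_id=algo.task_id)

print(logits.shape)           # logits from the current task's head
print(list(hidden_features))  # names of the backbone's weighted layers
```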
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:
Training step for current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of training data.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
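A hedged illustration of the returned dictionary (the batch shape and label range are assumptions); under Lightning's automatic optimization the `'loss'` entry is the one that gets backpropagated.

```python
import torch

batch = (torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,)))  # assumed (inputs, labels) pair
outputs = algo.training_step(batch)

# keys returned by this implementation:
#   'loss'            total loss used for backpropagation (here equal to 'loss_cls')
#   'loss_cls'        classification loss on the batch
#   'acc'             batch accuracy
#   'hidden_features' per-layer hidden features from the backbone
print(outputs["loss"].item(), outputs["acc"].item())
```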
def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
Validation step for current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of validation data.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
def test_step(self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.

**Args:**
- **batch** (`Any`): a batch of test data.
- **dataloader_idx** (`int`): the index of the seen task to be tested; the corresponding task ID is `dataloader_idx + 1`. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
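A hedged sketch of how the dataloader index maps to a task ID at test time: the implementation computes `test_task_id = dataloader_idx + 1`, so dataloader 0 tests task 1, dataloader 1 tests task 2, and so on. The batch construction and the number of seen tasks below are illustrative only.

```python
import torch

test_batch = (torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,)))  # assumed (inputs, labels) pair

for dataloader_idx in range(3):  # assume 3 seen tasks
    outputs = algo.test_step(test_batch, batch_idx=0, dataloader_idx=dataloader_idx)
    print(f"task {dataloader_idx + 1}: acc = {outputs['acc'].item():.3f}")
```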