clarena.cl_algorithms.finetuning
The submodule in cl_algorithms for Finetuning algorithm.
r"""
The submodule in `cl_algorithms` for Finetuning algorithm.
"""

__all__ = ["Finetuning", "AmnesiacFinetuning"]

import logging
from typing import Any

from torch import Tensor
from torch.utils.data import DataLoader

from clarena.backbones import CLBackbone
from clarena.cl_algorithms import AmnesiacCLAlgorithm, CLAlgorithm
from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class Finetuning(CLAlgorithm):
    r"""Finetuning algorithm.

    The most naive way for task-incremental learning. It simply initializes the backbone from the last task when training new task.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        r"""Initialize the Finetuning algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. Defaults to an empty dict when `None`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        # avoid the mutable default argument pitfall: a literal `{}` default
        # would be shared across all instances of this class
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}

        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
        """
        x, y = batch

        # classification loss
        logits, activations = self.forward(x, stage="train", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)

        # total loss
        loss = loss_cls

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
            "activations": activations,
        }

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Key (`str`) are the metrics names, value (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **batch_idx** (`int`): the index of the batch within the test dataloader.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Key (`str`) are the metrics name, value (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y = batch

        logits, _ = self.forward(
            x, stage="test", task_id=test_task_id
        )  # use the corresponding head to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,
        }


class AmnesiacFinetuning(AmnesiacCLAlgorithm, Finetuning):
    r"""Amnesiac Finetuning algorithm."""

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
        disable_unlearning: bool = False,
        **kwargs,
    ) -> None:
        r"""Initialize the Amnesiac Finetuning algorithm with the network.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]` | `None`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. Defaults to an empty dict when `None`.
        - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`.
        - **kwargs**: Reserved for multiple inheritance.
        """
        # avoid the mutable default argument pitfall (see Finetuning.__init__)
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}

        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            disable_unlearning=disable_unlearning,
            **kwargs,
        )

    def on_train_start(self) -> None:
        r"""Record backbone parameters before training current task."""
        # call both parent hooks explicitly: cooperative super() is not used
        # here because each parent's hook must run regardless of MRO details
        Finetuning.on_train_start(self)
        AmnesiacCLAlgorithm.on_train_start(self)

    def on_train_end(self) -> None:
        r"""Record backbone parameters after training current task."""
        Finetuning.on_train_end(self)
        AmnesiacCLAlgorithm.on_train_end(self)
class
Finetuning(clarena.cl_algorithms.base.CLAlgorithm):
22class Finetuning(CLAlgorithm): 23 r"""Finetuning algorithm. 24 25 The most naive way for task-incremental learning. It simply initializes the backbone from the last task when training new task. 26 """ 27 28 def __init__( 29 self, 30 backbone: CLBackbone, 31 heads: HeadsTIL | HeadsCIL | HeadDIL, 32 non_algorithmic_hparams: dict[str, Any] = {}, 33 **kwargs, 34 ) -> None: 35 r"""Initialize the Finetuning algorithm with the network. It has no additional hyperparameters. 36 37 **Args:** 38 - **backbone** (`CLBackbone`): backbone network. 39 - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads. 40 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 41 - **kwargs**: Reserved for multiple inheritance. 42 """ 43 super().__init__( 44 backbone=backbone, 45 heads=heads, 46 non_algorithmic_hparams=non_algorithmic_hparams, 47 **kwargs, 48 ) 49 50 def training_step(self, batch: Any) -> dict[str, Tensor]: 51 """Training step for current task `self.task_id`. 52 53 **Args:** 54 - **batch** (`Any`): a batch of training data. 55 56 **Returns:** 57 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 
58 """ 59 x, y = batch 60 61 # classification loss 62 logits, activations = self.forward(x, stage="train", task_id=self.task_id) 63 loss_cls = self.criterion(logits, y) 64 65 # total loss 66 loss = loss_cls 67 68 # predicted labels 69 preds = logits.argmax(dim=1) 70 71 # accuracy of the batch 72 acc = (preds == y).float().mean() 73 74 return { 75 "preds": preds, 76 "loss": loss, # return loss is essential for training step, or backpropagation will fail 77 "loss_cls": loss_cls, 78 "acc": acc, # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()` 79 "activations": activations, 80 } 81 82 def validation_step(self, batch: Any) -> dict[str, Tensor]: 83 r"""Validation step for current task `self.task_id`. 84 85 **Args:** 86 - **batch** (`Any`): a batch of validation data. 87 88 **Returns:** 89 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Key (`str`) are the metrics names, value (`Tensor`) are the metrics. 90 """ 91 x, y = batch 92 logits, _ = self.forward(x, stage="validation", task_id=self.task_id) 93 loss_cls = self.criterion(logits, y) 94 preds = logits.argmax(dim=1) 95 acc = (preds == y).float().mean() 96 97 # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()` 98 return { 99 "preds": preds, 100 "loss_cls": loss_cls, 101 "acc": acc, 102 } 103 104 def test_step( 105 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 106 ) -> dict[str, Tensor]: 107 r"""Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`. 108 109 **Args:** 110 - **batch** (`Any`): a batch of test data. 111 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 112 113 **Returns:** 114 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. 
Key (`str`) are the metrics name, value (`Tensor`) are the metrics. 115 """ 116 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 117 118 x, y = batch 119 120 logits, _ = self.forward( 121 x, stage="test", task_id=test_task_id 122 ) # use the corresponding head to test (instead of the current task `self.task_id`) 123 loss_cls = self.criterion(logits, y) 124 preds = logits.argmax(dim=1) 125 acc = (preds == y).float().mean() 126 127 # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 128 return { 129 "preds": preds, 130 "loss_cls": loss_cls, 131 "acc": acc, 132 }
Finetuning algorithm.
The most naive way for task-incremental learning. It simply initializes the backbone from the last task when training new task.
Finetuning( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
28 def __init__( 29 self, 30 backbone: CLBackbone, 31 heads: HeadsTIL | HeadsCIL | HeadDIL, 32 non_algorithmic_hparams: dict[str, Any] = {}, 33 **kwargs, 34 ) -> None: 35 r"""Initialize the Finetuning algorithm with the network. It has no additional hyperparameters. 36 37 **Args:** 38 - **backbone** (`CLBackbone`): backbone network. 39 - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads. 40 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 41 - **kwargs**: Reserved for multiple inheritance. 42 """ 43 super().__init__( 44 backbone=backbone, 45 heads=heads, 46 non_algorithmic_hparams=non_algorithmic_hparams, 47 **kwargs, 48 )
Initialize the Finetuning algorithm with the network. It has no additional hyperparameters.
Args:
- backbone (
CLBackbone): backbone network. - heads (
HeadsTIL|HeadsCIL|HeadDIL): output heads. - non_algorithmic_hparams (
dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. - kwargs: Reserved for multiple inheritance.
def
training_step(self, batch: Any) -> dict[str, torch.Tensor]:
50 def training_step(self, batch: Any) -> dict[str, Tensor]: 51 """Training step for current task `self.task_id`. 52 53 **Args:** 54 - **batch** (`Any`): a batch of training data. 55 56 **Returns:** 57 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 58 """ 59 x, y = batch 60 61 # classification loss 62 logits, activations = self.forward(x, stage="train", task_id=self.task_id) 63 loss_cls = self.criterion(logits, y) 64 65 # total loss 66 loss = loss_cls 67 68 # predicted labels 69 preds = logits.argmax(dim=1) 70 71 # accuracy of the batch 72 acc = (preds == y).float().mean() 73 74 return { 75 "preds": preds, 76 "loss": loss, # return loss is essential for training step, or backpropagation will fail 77 "loss_cls": loss_cls, 78 "acc": acc, # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()` 79 "activations": activations, 80 }
Training step for current task self.task_id.
Args:
- batch (
Any): a batch of training data.
Returns:
- outputs (
dict[str, Tensor]): a dictionary contains loss and other metrics from this training step. Keys (str) are the metrics names, and values (Tensor) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
def
validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
82 def validation_step(self, batch: Any) -> dict[str, Tensor]: 83 r"""Validation step for current task `self.task_id`. 84 85 **Args:** 86 - **batch** (`Any`): a batch of validation data. 87 88 **Returns:** 89 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Key (`str`) are the metrics names, value (`Tensor`) are the metrics. 90 """ 91 x, y = batch 92 logits, _ = self.forward(x, stage="validation", task_id=self.task_id) 93 loss_cls = self.criterion(logits, y) 94 preds = logits.argmax(dim=1) 95 acc = (preds == y).float().mean() 96 97 # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()` 98 return { 99 "preds": preds, 100 "loss_cls": loss_cls, 101 "acc": acc, 102 }
Validation step for current task self.task_id.
Args:
- batch (
Any): a batch of validation data.
Returns:
- outputs (
dict[str, Tensor]): a dictionary contains loss and other metrics from this validation step. Key (str) are the metrics names, value (Tensor) are the metrics.
def
test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
104 def test_step( 105 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 106 ) -> dict[str, Tensor]: 107 r"""Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`. 108 109 **Args:** 110 - **batch** (`Any`): a batch of test data. 111 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 112 113 **Returns:** 114 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Key (`str`) are the metrics name, value (`Tensor`) are the metrics. 115 """ 116 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 117 118 x, y = batch 119 120 logits, _ = self.forward( 121 x, stage="test", task_id=test_task_id 122 ) # use the corresponding head to test (instead of the current task `self.task_id`) 123 loss_cls = self.criterion(logits, y) 124 preds = logits.argmax(dim=1) 125 acc = (preds == y).float().mean() 126 127 # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 128 return { 129 "preds": preds, 130 "loss_cls": loss_cls, 131 "acc": acc, 132 }
Test step for current task self.task_id, which tests all seen tasks indexed by dataloader_idx.
Args:
- batch (
Any): a batch of test data. - dataloader_idx (
int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
Returns:
- outputs (
dict[str, Tensor]): a dictionary contains loss and other metrics from this test step. Key (str) are the metrics name, value (Tensor) are the metrics.
135class AmnesiacFinetuning(AmnesiacCLAlgorithm, Finetuning): 136 r"""Amnesiac Finetuning algorithm.""" 137 138 def __init__( 139 self, 140 backbone: CLBackbone, 141 heads: HeadsTIL | HeadsCIL | HeadDIL, 142 non_algorithmic_hparams: dict[str, Any] = {}, 143 disable_unlearning: bool = False, 144 **kwargs, 145 ) -> None: 146 r"""Initialize the Amnesiac Finetuning algorithm with the network. 147 148 **Args:** 149 - **backbone** (`CLBackbone`): backbone network. 150 - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads. 151 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 152 - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`. 153 - **kwargs**: Reserved for multiple inheritance. 154 """ 155 super().__init__( 156 backbone=backbone, 157 heads=heads, 158 non_algorithmic_hparams=non_algorithmic_hparams, 159 disable_unlearning=disable_unlearning, 160 **kwargs, 161 ) 162 163 def on_train_start(self) -> None: 164 """Record backbone parameters before training current task.""" 165 Finetuning.on_train_start(self) 166 AmnesiacCLAlgorithm.on_train_start(self) 167 168 def on_train_end(self) -> None: 169 """Record backbone parameters before training current task.""" 170 Finetuning.on_train_end(self) 171 AmnesiacCLAlgorithm.on_train_end(self)
Amnesiac Finetuning algorithm.
AmnesiacFinetuning( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, non_algorithmic_hparams: dict[str, typing.Any] = {}, disable_unlearning: bool = False, **kwargs)
138 def __init__( 139 self, 140 backbone: CLBackbone, 141 heads: HeadsTIL | HeadsCIL | HeadDIL, 142 non_algorithmic_hparams: dict[str, Any] = {}, 143 disable_unlearning: bool = False, 144 **kwargs, 145 ) -> None: 146 r"""Initialize the Amnesiac Finetuning algorithm with the network. 147 148 **Args:** 149 - **backbone** (`CLBackbone`): backbone network. 150 - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads. 151 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 152 - **disable_unlearning** (`bool`): whether to disable the unlearning functionality. Default is `False`. 153 - **kwargs**: Reserved for multiple inheritance. 154 """ 155 super().__init__( 156 backbone=backbone, 157 heads=heads, 158 non_algorithmic_hparams=non_algorithmic_hparams, 159 disable_unlearning=disable_unlearning, 160 **kwargs, 161 )
Initialize the Amnesiac Finetuning algorithm with the network.
Args:
- backbone (
CLBackbone): backbone network. - heads (
HeadsTIL|HeadsCIL|HeadDIL): output heads. - non_algorithmic_hparams (
dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. - disable_unlearning (
bool): whether to disable the unlearning functionality. Default is `False`. - kwargs: Reserved for multiple inheritance.
def
on_train_start(self) -> None:
163 def on_train_start(self) -> None: 164 """Record backbone parameters before training current task.""" 165 Finetuning.on_train_start(self) 166 AmnesiacCLAlgorithm.on_train_start(self)
Record backbone parameters before training current task.
def
on_train_end(self) -> None:
168 def on_train_end(self) -> None: 169 """Record backbone parameters before training current task.""" 170 Finetuning.on_train_end(self) 171 AmnesiacCLAlgorithm.on_train_end(self)
Record backbone parameters after training current task.