clarena.cl_algorithms
Continual Learning Algorithms
This submodule provides the continual learning algorithms in CLArena.
Here are the base classes for CL algorithms, which inherit from PyTorch Lightning LightningModule:
- `CLAlgorithm`: the base class for all continual learning algorithms.
    - `UnlearnableCLAlgorithm`: the base class for unlearnable continual learning algorithms.
Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement CL algorithms:

- [**Configure CL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/cl-algorithm)
- [**Implement Custom CL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl-algorithm)
- [**A Beginners' Guide to Continual Learning (Methodology Overview)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-methodology)
```python
r"""

# Continual Learning Algorithms

This submodule provides the **continual learning algorithms** in CLArena.

Here are the base classes for CL algorithms, which inherit from PyTorch Lightning `LightningModule`:

- `CLAlgorithm`: the base class for all continual learning algorithms.
    - `UnlearnableCLAlgorithm`: the base class for unlearnable continual learning algorithms.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement CL algorithms:

- [**Configure CL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/cl-algorithm)
- [**Implement Custom CL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/cl-algorithm)
- [**A Beginners' Guide to Continual Learning (Methodology Overview)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-methodology)

"""

from .base import CLAlgorithm, UnlearnableCLAlgorithm

# finetuning first
from .finetuning import Finetuning
from .independent import Independent, UnlearnableIndependent
from .fix import Fix
from .random import Random

from .lwf import LwF
from .ewc import EWC
from .cbp import CBP

from .hat import HAT
from .adahat import AdaHAT
from .fgadahat import FGAdaHAT
from .wsn import WSN

# from .nispa import NISPA


__all__ = [
    "CLAlgorithm",
    "UnlearnableCLAlgorithm",
    "regularizers",
    "finetuning",
    "independent",
    "fix",
    "random",
    "lwf",
    "ewc",
    "cbp",
    "hat",
    "adahat",
    "fgadahat",
    "wsn",
    # "nispa",
]
```
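For orientation, here is a short, hedged import sketch that uses only names visible in the `__init__.py` source above (it assumes the package is installed under its `clarena` namespace):

```python
# Base classes and a few concrete algorithms re-exported by this submodule
# (names taken from the `__init__.py` source above).
from clarena.cl_algorithms import (
    CLAlgorithm,             # base class for all CL algorithms
    UnlearnableCLAlgorithm,  # base class for unlearnable CL algorithms
    Finetuning,              # vanilla finetuning baseline
    EWC,                     # Elastic Weight Consolidation
    HAT,                     # Hard Attention to the Task
)

# Concrete algorithms inherit from `CLAlgorithm`, which is a PyTorch Lightning
# `LightningModule`, so instances can be driven by a Lightning `Trainer`.
print(CLAlgorithm.__mro__)
```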
```python
class CLAlgorithm(LightningModule):
    r"""The base class of continual learning algorithms."""

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""
        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (not related to the algorithm itself) passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__()
        self.save_hyperparameters(non_algorithmic_hparams)

        # components
        self.backbone: CLBackbone = backbone
        r"""The backbone network."""
        self.heads: HeadsTIL | HeadsCIL = heads
        r"""The output heads."""
        self.optimizer_t: Optimizer
        r"""Optimizer (partially initialized) for the current task `self.task_id`. Will be equipped with parameters in `configure_optimizers()`."""
        self.lr_scheduler_t: LRScheduler | None
        r"""Learning rate scheduler for the optimizer of the current task `self.task_id`. If `None`, no scheduler is used."""
        self.criterion = nn.CrossEntropyLoss()
        r"""Loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.if_forward_func_return_logits_only: bool = False
        r"""Whether the `forward()` method returns logits only. If `False`, it returns a tuple containing logits and other information. Default is `False`."""

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

        CLAlgorithm.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check backbone and heads compatibility
        if self.backbone.output_dim != self.heads.input_dim:
            raise ValueError(
                "The output_dim of the backbone must equal the input_dim of the CL heads."
            )

    def setup_task_id(
        self,
        task_id: int,
        num_classes: int,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ) -> None:
        r"""Set up which task the CL experiment is on. This must be done before the `forward()` method is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        - **num_classes** (`int`): the number of classes in the task.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialized) for the task.
        - **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
        """
        self.task_id = task_id
        self.processed_task_ids.append(task_id)
        self.backbone.setup_task_id(task_id=task_id)
        self.heads.setup_task_id(task_id, num_classes)
        self.optimizer_t = optimizer
        self.lr_scheduler_t = lr_scheduler

    def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        r"""Get the test task ID from the dataloader index.

        **Args:**
        - **dataloader_idx** (`int`): the dataloader index.

        **Returns:**
        - **test_task_id** (`int`): the test task ID.
        """
        dataset_test = self.trainer.datamodule.dataset_test
        test_task_id = list(dataset_test.keys())[dataloader_idx]
        return test_task_id

    def set_forward_func_return_logits_only(
        self, forward_func_return_logits_only: bool
    ) -> None:
        r"""Set whether the `forward()` method returns logits only. This is useful for some CL algorithms that require the forward function to return logits only, such as FG-AdaHAT.

        **Args:**
        - **forward_func_return_logits_only** (`bool`): whether the `forward()` method returns logits only. If `False`, it returns a tuple containing logits and other information.
        """
        self.if_forward_func_return_logits_only = forward_func_return_logits_only

    def preceding_layer(self, layer_name: str) -> nn.Module | None:
        r"""Get the preceding layer of the given layer (including backbone and output heads). If the given layer is the first layer, return `None`.

        **Args:**
        - **layer_name** (`str`): the name of the layer.

        **Returns:**
        - **preceding_layer** (`nn.Module` | `None`): the preceding layer.
        """

        if layer_name == "heads":
            backbone_last_layer_name = self.backbone.weighted_layer_names[-1]
            backbone_last_layer = self.backbone.get_layer_by_name(
                backbone_last_layer_name
            )
            return backbone_last_layer
        else:
            preceding_layer_name = self.backbone.preceding_layer_name(layer_name)
            preceding_layer = self.backbone.get_layer_by_name(preceding_layer_name)

        return preceding_layer

    def next_layer(self, layer_name: str) -> nn.Module | None:
        r"""Get the next layer of the given layer (including backbone and output heads). If the given layer is the last layer, return `None`.

        **Args:**
        - **layer_name** (`str`): the name of the layer.

        **Returns:**
        - **next_layer** (`nn.Module` | `None`): the next layer.
        """

        if layer_name == "heads":
            return None
        else:
            next_layer_name = self.backbone.next_layer_name(layer_name)
            if next_layer_name is not None:
                next_layer = self.backbone.get_layer_by_name(next_layer_name)
            else:
                next_layer = self.heads.get_head(self.task_id)

        return next_layer

    def forward(self, input: Tensor, stage: str, task_id: int | None = None) -> Tensor:
        r"""The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many CL algorithms, including the vanilla Finetuning algorithm, use. It works for both TIL and CIL.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int`): the task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is never used; best practice is not to provide this argument and leave it at the default value.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
        """
        feature, activations = self.backbone(input, stage=stage, task_id=task_id)
        logits = self.heads(feature, task_id)
        return (
            logits if self.if_forward_func_return_logits_only else (logits, activations)
        )

    def configure_optimizers(self) -> Optimizer:
        r"""Configure optimizers; this hook is called by Lightning. See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
        # finish the partially initialized optimizer by specifying the model parameters. The `parameters()` method of this `CLAlgorithm` (inherited from `LightningModule`) returns both backbone and heads parameters
        fully_initialized_optimizer = self.optimizer_t(params=self.parameters())

        if self.lr_scheduler_t:
            fully_initialized_lr_scheduler = self.lr_scheduler_t(
                optimizer=fully_initialized_optimizer
            )

            return {
                "optimizer": fully_initialized_optimizer,
                "lr_scheduler": {
                    "scheduler": fully_initialized_lr_scheduler,
                    "monitor": f"task_{self.task_id}/learning_curve/val/loss_cls",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }

        return {"optimizer": fully_initialized_optimizer}
```
The base class of continual learning algorithms.
```python
def __init__(
    self,
    backbone: CLBackbone,
    heads: HeadsTIL | HeadsCIL,
    non_algorithmic_hparams: dict[str, Any] = {},
) -> None:
    r"""
    **Args:**
    - **backbone** (`CLBackbone`): backbone network.
    - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
    - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (not related to the algorithm itself) passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
    """
    super().__init__()
    self.save_hyperparameters(non_algorithmic_hparams)

    # components
    self.backbone: CLBackbone = backbone
    r"""The backbone network."""
    self.heads: HeadsTIL | HeadsCIL = heads
    r"""The output heads."""
    self.optimizer_t: Optimizer
    r"""Optimizer (partially initialized) for the current task `self.task_id`. Will be equipped with parameters in `configure_optimizers()`."""
    self.lr_scheduler_t: LRScheduler | None
    r"""Learning rate scheduler for the optimizer of the current task `self.task_id`. If `None`, no scheduler is used."""
    self.criterion = nn.CrossEntropyLoss()
    r"""Loss function between the output logits and the target labels. Default is cross-entropy loss."""

    self.if_forward_func_return_logits_only: bool = False
    r"""Whether the `forward()` method returns logits only. If `False`, it returns a tuple containing logits and other information. Default is `False`."""

    # task ID control
    self.task_id: int
    r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
    self.processed_task_ids: list[int] = []
    r"""Task IDs that have been processed."""

    CLAlgorithm.sanity_check(self)
```
**Args:**

- **backbone** (`CLBackbone`): backbone network.
- **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
- **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (not related to the algorithm itself) passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
Optimizer (partially initialized) for the current task `self.task_id`. Will be equipped with parameters in `configure_optimizers()`.

Learning rate scheduler for the optimizer of the current task `self.task_id`. If `None`, no scheduler is used.

Loss function between the output logits and the target labels. Default is cross-entropy loss.

Whether the `forward()` method returns logits only. If `False`, it returns a tuple containing logits and other information. Default is `False`.

Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`.
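As a hedged illustration of how this constructor is typically used by subclasses (the `MyRegularizedAlgorithm` class and its `reg_factor` hyperparameter below are hypothetical, not part of CLArena):

```python
from typing import Any

from clarena.cl_algorithms import CLAlgorithm


class MyRegularizedAlgorithm(CLAlgorithm):
    """Hypothetical subclass used only to illustrate the constructor contract."""

    def __init__(
        self,
        backbone,                 # a `CLBackbone` instance
        heads,                    # a `HeadsTIL` or `HeadsCIL` instance
        reg_factor: float = 0.1,  # hypothetical algorithmic hyperparameter
        non_algorithmic_hparams: dict[str, Any] | None = None,
    ) -> None:
        # Optimizer / LR-scheduler settings from the config travel through
        # `non_algorithmic_hparams`; the base constructor records them with
        # Lightning's `save_hyperparameters()` for logging and reproducibility.
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams or {},
        )
        self.reg_factor = reg_factor
```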
```python
def sanity_check(self) -> None:
    r"""Sanity check."""

    # check backbone and heads compatibility
    if self.backbone.output_dim != self.heads.input_dim:
        raise ValueError(
            "The output_dim of the backbone must equal the input_dim of the CL heads."
        )
```
Sanity check.
```python
def setup_task_id(
    self,
    task_id: int,
    num_classes: int,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler | None,
) -> None:
    r"""Set up which task the CL experiment is on. This must be done before the `forward()` method is called.

    **Args:**
    - **task_id** (`int`): the target task ID.
    - **num_classes** (`int`): the number of classes in the task.
    - **optimizer** (`Optimizer`): the optimizer object (partially initialized) for the task.
    - **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
    """
    self.task_id = task_id
    self.processed_task_ids.append(task_id)
    self.backbone.setup_task_id(task_id=task_id)
    self.heads.setup_task_id(task_id, num_classes)
    self.optimizer_t = optimizer
    self.lr_scheduler_t = lr_scheduler
```
Set up which task the CL experiment is on. This must be done before the `forward()` method is called.
**Args:**

- **task_id** (`int`): the target task ID.
- **num_classes** (`int`): the number of classes in the task.
- **optimizer** (`Optimizer`): the optimizer object (partially initialized) for the task.
- **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
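Because `configure_optimizers()` later completes the optimizer with `self.optimizer_t(params=self.parameters())`, the `optimizer` argument here should be a partially initialized factory rather than a constructed optimizer. A minimal sketch, assuming a `functools.partial` stands in for the partially initialized config object and `algo` is a hypothetical, already constructed algorithm instance:

```python
from functools import partial

import torch

# Partially initialized optimizer / scheduler factories: everything is fixed
# except the `params` / `optimizer` arguments, which `configure_optimizers()`
# fills in later via `self.optimizer_t(params=self.parameters())`.
optimizer_t = partial(torch.optim.SGD, lr=0.01, momentum=0.9)
lr_scheduler_t = partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5)

# `algo` is a hypothetical, already constructed CL algorithm instance.
# Before training task 1, point the algorithm (and its backbone and heads) at it.
algo.setup_task_id(
    task_id=1,
    num_classes=10,        # example value: number of classes in task 1
    optimizer=optimizer_t,
    lr_scheduler=lr_scheduler_t,
)
```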
```python
def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
    r"""Get the test task ID from the dataloader index.

    **Args:**
    - **dataloader_idx** (`int`): the dataloader index.

    **Returns:**
    - **test_task_id** (`int`): the test task ID.
    """
    dataset_test = self.trainer.datamodule.dataset_test
    test_task_id = list(dataset_test.keys())[dataloader_idx]
    return test_task_id
```
Get the test task ID from the dataloader index.
**Args:**

- **dataloader_idx** (`int`): the dataloader index.

**Returns:**

- **test_task_id** (`int`): the test task ID.
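The mapping simply follows the key order of the test-dataset dictionary held by the datamodule. A self-contained sketch with a hypothetical `dataset_test` dictionary:

```python
# Hypothetical `dataset_test` dict as held by the datamodule; keys are task IDs.
dataset_test = {1: "test set of task 1", 2: "test set of task 2", 5: "test set of task 5"}

# Lightning numbers the test dataloaders 0, 1, 2, ... in the same order,
# so dataloader index 2 maps back to task ID 5 here.
dataloader_idx = 2
test_task_id = list(dataset_test.keys())[dataloader_idx]
assert test_task_id == 5
```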
```python
def set_forward_func_return_logits_only(
    self, forward_func_return_logits_only: bool
) -> None:
    r"""Set whether the `forward()` method returns logits only. This is useful for some CL algorithms that require the forward function to return logits only, such as FG-AdaHAT.

    **Args:**
    - **forward_func_return_logits_only** (`bool`): whether the `forward()` method returns logits only. If `False`, it returns a tuple containing logits and other information.
    """
    self.if_forward_func_return_logits_only = forward_func_return_logits_only
```
Set whether the `forward()` method returns logits only. This is useful for some CL algorithms that require the forward function to return logits only, such as FG-AdaHAT.
**Args:**

- **forward_func_return_logits_only** (`bool`): whether the `forward()` method returns logits only. If `False`, it returns a tuple containing logits and other information.
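A hedged usage sketch (`algo` and `x` are hypothetical placeholders for a set-up algorithm instance and an input batch tensor):

```python
# Switch `forward()` to return logits only, e.g. when an external routine
# expects a plain tensor rather than the default tuple.
algo.set_forward_func_return_logits_only(True)
logits = algo(x, stage="train", task_id=algo.task_id)               # a plain Tensor

# Switch back to the default tuple output.
algo.set_forward_func_return_logits_only(False)
logits, activations = algo(x, stage="train", task_id=algo.task_id)
```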
```python
def preceding_layer(self, layer_name: str) -> nn.Module | None:
    r"""Get the preceding layer of the given layer (including backbone and output heads). If the given layer is the first layer, return `None`.

    **Args:**
    - **layer_name** (`str`): the name of the layer.

    **Returns:**
    - **preceding_layer** (`nn.Module` | `None`): the preceding layer.
    """

    if layer_name == "heads":
        backbone_last_layer_name = self.backbone.weighted_layer_names[-1]
        backbone_last_layer = self.backbone.get_layer_by_name(
            backbone_last_layer_name
        )
        return backbone_last_layer
    else:
        preceding_layer_name = self.backbone.preceding_layer_name(layer_name)
        preceding_layer = self.backbone.get_layer_by_name(preceding_layer_name)

    return preceding_layer
```
Get the preceding layer of the given layer (including backbone and output heads). If the given layer is the first layer, return `None`.
**Args:**

- **layer_name** (`str`): the name of the layer.

**Returns:**

- **preceding_layer** (`nn.Module` | `None`): the preceding layer.
```python
def next_layer(self, layer_name: str) -> nn.Module | None:
    r"""Get the next layer of the given layer (including backbone and output heads). If the given layer is the last layer, return `None`.

    **Args:**
    - **layer_name** (`str`): the name of the layer.

    **Returns:**
    - **next_layer** (`nn.Module` | `None`): the next layer.
    """

    if layer_name == "heads":
        return None
    else:
        next_layer_name = self.backbone.next_layer_name(layer_name)
        if next_layer_name is not None:
            next_layer = self.backbone.get_layer_by_name(next_layer_name)
        else:
            next_layer = self.heads.get_head(self.task_id)

    return next_layer
```
Get the next layer of the given layer (including backbone and output heads). If the given layer is the last layer, return `None`.
**Args:**

- **layer_name** (`str`): the name of the layer.

**Returns:**

- **next_layer** (`nn.Module` | `None`): the next layer.
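A hedged sketch of how `preceding_layer()` and `next_layer()` chain the backbone layers and the output heads together (`algo` is a hypothetical, constructed algorithm instance):

```python
# `algo.backbone.weighted_layer_names` is the ordered list of weighted layer
# names, as used in the source above.
for layer_name in algo.backbone.weighted_layer_names:
    prev_module = algo.preceding_layer(layer_name)  # None for the first layer
    next_module = algo.next_layer(layer_name)       # current task's head after the last layer
    print(layer_name, type(prev_module).__name__, type(next_module).__name__)

# The "heads" pseudo-layer sits at the end of the chain: its preceding layer is
# the last weighted backbone layer, and it has no next layer.
assert algo.next_layer("heads") is None
```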
```python
def forward(self, input: Tensor, stage: str, task_id: int | None = None) -> Tensor:
    r"""The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many CL algorithms, including the vanilla Finetuning algorithm, use. It works for both TIL and CIL.

    **Args:**
    - **input** (`Tensor`): the input tensor from data.
    - **stage** (`str`): the stage of the forward pass; one of:
        1. 'train': training stage.
        2. 'validation': validation stage.
        3. 'test': testing stage.
    - **task_id** (`int`): the task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is never used; best practice is not to provide this argument and leave it at the default value.

    **Returns:**
    - **logits** (`Tensor`): the output logits tensor.
    - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
    """
    feature, activations = self.backbone(input, stage=stage, task_id=task_id)
    logits = self.heads(feature, task_id)
    return (
        logits if self.if_forward_func_return_logits_only else (logits, activations)
    )
```
The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many CL algorithms, including the vanilla Finetuning algorithm, use. It works for both TIL and CIL.
**Args:**

- **input** (`Tensor`): the input tensor from data.
- **stage** (`str`): the stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **task_id** (`int`): the task ID where the data are from. If the stage is 'train' or 'validation', it is usually the current task `self.task_id`. If the stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. In CIL, they are not provided, so it is just a placeholder for API consistency and is never used; best practice is not to provide this argument and leave it at the default value.

**Returns:**

- **logits** (`Tensor`): the output logits tensor.
- **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
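A hedged sketch of calling this forward template and consuming its tuple output (`algo` is a hypothetical, set-up algorithm instance; the input shape and class count are example values):

```python
import torch

# Example batch: 4 RGB images of size 32x32, labels from 10 classes.
x = torch.randn(4, 3, 32, 32)
y = torch.randint(0, 10, (4,))

# Default mode: the template returns a (logits, activations) tuple.
logits, activations = algo(x, stage="train", task_id=algo.task_id)

loss = algo.criterion(logits, y)                # cross-entropy loss by default
print(loss.item(), list(activations.keys()))    # activations keyed by weighted layer name
```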
```python
def configure_optimizers(self) -> Optimizer:
    r"""Configure optimizers; this hook is called by Lightning. See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
    # finish the partially initialized optimizer by specifying the model parameters. The `parameters()` method of this `CLAlgorithm` (inherited from `LightningModule`) returns both backbone and heads parameters
    fully_initialized_optimizer = self.optimizer_t(params=self.parameters())

    if self.lr_scheduler_t:
        fully_initialized_lr_scheduler = self.lr_scheduler_t(
            optimizer=fully_initialized_optimizer
        )

        return {
            "optimizer": fully_initialized_optimizer,
            "lr_scheduler": {
                "scheduler": fully_initialized_lr_scheduler,
                "monitor": f"task_{self.task_id}/learning_curve/val/loss_cls",
                "interval": "epoch",
                "frequency": 1,
            },
        }

    return {"optimizer": fully_initialized_optimizer}
```
Configure optimizers; this hook is called by Lightning. See the [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details.
```python
class UnlearnableCLAlgorithm(CLAlgorithm):
    r"""The base class of unlearnable continual learning algorithms."""

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""
        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (not related to the algorithm itself) passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )

        self.unlearning_task_ids: list[int]
        r"""The list of task IDs that are requested to be unlearned after training `self.task_id`."""

        self.unlearned_task_ids: set[int] = set()
        r"""The set of task IDs that have been unlearned in the experiment."""

        UnlearnableCLAlgorithm.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

    def aggregated_backbone_output(self, input: Tensor) -> Tensor:
        r"""Get the aggregated backbone output for the input data. All parts of the backbone should be aggregated together.

        This output feature is used for measuring unlearning metrics, such as Distribution Distance (DD). An aggregated output involving every part of the backbone is needed to ensure the fairness of the metric.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.

        **Returns:**
        - **output** (`Tensor`): the aggregated backbone output tensor.
        """
        feature = 0

        for i in self.processed_task_ids:
            feature_i = self.backbone(input, stage="train", task_id=i)[0]
            feature += feature_i
        feature = feature / len(self.processed_task_ids)

        return feature
```
The base class of unlearnable continual learning algorithms.
```python
def __init__(
    self,
    backbone: CLBackbone,
    heads: HeadsTIL | HeadsCIL,
    non_algorithmic_hparams: dict[str, Any] = {},
) -> None:
    r"""
    **Args:**
    - **backbone** (`CLBackbone`): backbone network.
    - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
    - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (not related to the algorithm itself) passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
    """
    super().__init__(
        backbone=backbone,
        heads=heads,
        non_algorithmic_hparams=non_algorithmic_hparams,
    )

    self.unlearning_task_ids: list[int]
    r"""The list of task IDs that are requested to be unlearned after training `self.task_id`."""

    self.unlearned_task_ids: set[int] = set()
    r"""The set of task IDs that have been unlearned in the experiment."""

    UnlearnableCLAlgorithm.sanity_check(self)
```
**Args:**

- **backbone** (`CLBackbone`): backbone network.
- **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
- **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (not related to the algorithm itself) passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
The list of task IDs that are requested to be unlearned after training `self.task_id`.
```python
def aggregated_backbone_output(self, input: Tensor) -> Tensor:
    r"""Get the aggregated backbone output for the input data. All parts of the backbone should be aggregated together.

    This output feature is used for measuring unlearning metrics, such as Distribution Distance (DD). An aggregated output involving every part of the backbone is needed to ensure the fairness of the metric.

    **Args:**
    - **input** (`Tensor`): the input tensor from data.

    **Returns:**
    - **output** (`Tensor`): the aggregated backbone output tensor.
    """
    feature = 0

    for i in self.processed_task_ids:
        feature_i = self.backbone(input, stage="train", task_id=i)[0]
        feature += feature_i
    feature = feature / len(self.processed_task_ids)

    return feature
```
Get the aggregated backbone output for the input data. All parts of the backbone should be aggregated together.
This output feature is used for measuring unlearning metrics, such as Distribution Distance (DD). An aggregated output involving every part of the backbone is needed to ensure the fairness of the metric.
**Args:**

- **input** (`Tensor`): the input tensor from data.

**Returns:**

- **output** (`Tensor`): the aggregated backbone output tensor.
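A hedged sketch of how such aggregated features might feed a distribution-distance style comparison; the model handles and the simple mean-squared distance below are illustrative stand-ins, not CLArena's actual DD implementation:

```python
import torch

# `reference_algo`, `unlearned_algo`, and `x_forget` are hypothetical: two model
# states (before and after unlearning) and an input batch drawn from the forgotten task.
feat_reference = reference_algo.aggregated_backbone_output(x_forget)
feat_unlearned = unlearned_algo.aggregated_backbone_output(x_forget)

# Toy proxy for a distribution-distance style metric: squared difference of the
# mean aggregated features across the batch.
dd_proxy = torch.mean((feat_reference.mean(dim=0) - feat_unlearned.mean(dim=0)) ** 2)
print(float(dd_proxy))
```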