clarena.cl_algorithms.cbp
The submodule in `cl_algorithms` for the [CBP (Continual Backpropagation)](https://www.nature.com/articles/s41586-024-07711-7) algorithm.
r"""
The submodule in `cl_algorithms` for the [CBP (Continual Backpropagation)](https://www.nature.com/articles/s41586-024-07711-7) algorithm.
"""

__all__ = ["CBP"]

import logging
from typing import Any

import torch
from torch import Tensor

from clarena.backbones import CLBackbone
from clarena.cl_algorithms import Finetuning
from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL
from clarena.utils.transforms import min_max_normalize

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CBP(Finetuning):
    r"""[CBP (Continual Backpropagation)](https://www.nature.com/articles/s41586-024-07711-7) algorithm.

    A continual learning approach that reinitializes a small number of units during training, using a utility measure to determine which units to reinitialize. It aims to address the loss-of-plasticity problem when learning new tasks, but it does not solve the catastrophic forgetting problem in continual learning particularly well.

    We implement CBP as a subclass of the `Finetuning` algorithm, as CBP shares the same `forward()`, `training_step()`, `validation_step()` and `test_step()` methods with the `Finetuning` class.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        replacement_rate: float,
        maturity_threshold: int,
        utility_decay_rate: float,
        non_algorithmic_hparams: dict[str, Any] = {},
        **kwargs,
    ) -> None:
        r"""Initialize the CBP algorithm with the network and its hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **replacement_rate** (`float`): the replacement rate of units, i.e. the percentage of units to be reinitialized during training.
        - **maturity_threshold** (`int`): the maturity threshold of units, i.e. the number of training steps before a unit can be reinitialized.
        - **utility_decay_rate** (`float`): the utility decay rate of units, i.e. the rate at which the utility of a unit decays over time.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself, passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        - **kwargs**: reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        self.replacement_rate: float = replacement_rate
        r"""The replacement rate of units."""
        self.maturity_threshold: int = maturity_threshold
        r"""The maturity threshold of units."""
        self.utility_decay_rate: float = utility_decay_rate
        r"""The utility decay rate of units."""

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "replacement_rate",
            "maturity_threshold",
            "utility_decay_rate",
        )

        self.contribution_utility: dict[str, Tensor] = {}
        r"""The contribution utility of units. See equation (1) in the [continual backpropagation paper](https://www.nature.com/articles/s41586-024-07711-7). Keys are layer names and values are the utility tensors for the layers. Each utility tensor has the same size as the layer's feature tensor, i.e. (number of units, )."""
        self.num_replacements: dict[str, int] = {}
        r"""The number of replacements of units in each layer. Keys are layer names and values are the numbers of replacements for the layers."""
        self.age: dict[str, Tensor] = {}
        r"""The age of units. Keys are layer names and values are the age tensors for the layers. Each age tensor has the same size as the layer's feature tensor, i.e. (number of units, )."""

    def on_train_start(self) -> None:
        r"""Initialize the utility, number of replacements and age for each layer as zeros."""

        # initialize the utility, number of replacements and age as zeros at the
        # beginning of the first task. This cannot be done in the `__init__()` method
        # because `self.device` is not available at that time.
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name
                num_units = layer.weight.shape[0]

                self.contribution_utility[layer_name] = torch.zeros(num_units).to(
                    self.device
                )
                self.num_replacements[layer_name] = 0
                self.age[layer_name] = torch.zeros(num_units).to(self.device)

    def on_train_batch_end(
        self, outputs: dict[str, Any], batch: Any, batch_idx: int
    ) -> None:
        r"""Update the contribution utility and age of units after each training step, and reinitialize units based on the utility measure. This is the core of the CBP algorithm.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, i.e. the returns of the `training_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        - **batch_idx** (`int`): the index of the current batch.
        """

        activations = outputs["activations"]

        for layer_name in self.backbone.weighted_layer_names:
            # layer-wise operation

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # update age
            self.age[layer_name] += 1

            # calculate current contribution utility
            current_contribution_utility = (
                torch.mean(
                    torch.abs(activations[layer_name]),
                    dim=0,  # average the features over batch samples
                )
                * torch.sum(
                    torch.abs(layer.weight),
                    dim=1,  # sum the absolute weights of each unit
                )
            ).detach()
            current_contribution_utility = min_max_normalize(
                current_contribution_utility
            )  # normalize the utility to [0, 1] to avoid linearly increasing utility

            # update utility as an exponential moving average
            self.contribution_utility[layer_name] = (
                self.utility_decay_rate * self.contribution_utility[layer_name]
                + (1 - self.utility_decay_rate) * current_contribution_utility
            )

            # find eligible units
            eligible_mask = self.age[layer_name] > self.maturity_threshold
            eligible_indices = torch.where(eligible_mask)[0]

            # update the number of replacements
            num_eligible_units = eligible_indices.numel()
            self.num_replacements[layer_name] += int(
                self.replacement_rate * num_eligible_units
            )

            # if the number of replacements is greater than 1, execute the replacement
            if self.num_replacements[layer_name] > 1:

                # find the unit with the smallest age-normalized utility among eligible units
                replaced_unit_idx = eligible_indices[
                    torch.argmin(
                        self.contribution_utility[layer_name][eligible_indices]
                        / self.age[layer_name][eligible_indices]
                    ).item()
                ]

                # reinitialize the input weights of the unit
                preceding_layer = self.backbone.preceding_layer(layer_name)
                if preceding_layer is not None:
                    with torch.no_grad():
                        preceding_layer.weight[:, replaced_unit_idx] = torch.rand_like(
                            preceding_layer.weight[:, replaced_unit_idx]
                        )

                # reinitialize the output weights of the unit
                with torch.no_grad():
                    layer.weight[replaced_unit_idx] = torch.rand_like(
                        layer.weight[replaced_unit_idx]
                    )

                # reinitialize the utility of the unit
                self.contribution_utility[layer_name][replaced_unit_idx] = 0.0

                # reinitialize the age of the unit
                self.age[layer_name][replaced_unit_idx] = 0

                # update the number of replacements
                self.num_replacements[layer_name] -= 1
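The replacement step in `on_train_batch_end()` (select the mature unit with the lowest age-normalized utility, re-draw its weights, and reset its bookkeeping) can be sketched on a toy layer. This is a minimal standalone sketch under assumed shapes, not the class's actual method:

```python
import torch

torch.manual_seed(0)
num_units = 6
maturity_threshold = 100

age = torch.tensor([120.0, 5.0, 300.0, 80.0, 450.0, 200.0])
utility = torch.rand(num_units)
layer = torch.nn.Linear(4, num_units)

# units become eligible for replacement once their age exceeds the maturity threshold
eligible_indices = torch.where(age > maturity_threshold)[0]

# pick the eligible unit with the smallest age-normalized utility
replaced_unit_idx = eligible_indices[
    torch.argmin(utility[eligible_indices] / age[eligible_indices]).item()
]

# re-draw the unit's weights and reset its utility and age, mirroring the
# bookkeeping resets in the CBP replacement step
with torch.no_grad():
    layer.weight[replaced_unit_idx] = torch.rand_like(layer.weight[replaced_unit_idx])
utility[replaced_unit_idx] = 0.0
age[replaced_unit_idx] = 0
```

Dividing utility by age biases replacement toward units that have had many steps to accumulate utility but have not, which is why freshly reset units (age 0) are protected by the maturity threshold rather than the ratio itself.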