clarena.cl_algorithms.cbp
The submodule in `cl_algorithms` for the CBP (Continual Backpropagation) algorithm.
1r""" 2The submodule in `cl_algorithms` for [CBP (Continual Backpropagation)](https://www.nature.com/articles/s41586-024-07711-7) algorithm. 3""" 4 5__all__ = ["CBP"] 6 7import logging 8from typing import Any 9 10import torch 11from torch import Tensor 12 13from clarena.backbones import CLBackbone 14from clarena.cl_algorithms import Finetuning 15from clarena.cl_heads import HeadsCIL, HeadsTIL 16from clarena.utils.transforms import min_max_normalise 17 18# always get logger for built-in logging in each module 19pylogger = logging.getLogger(__name__) 20 21 22class CBP(Finetuning): 23 r"""CBP (Continual Backpropagation) algorithm. 24 25 [CBP (Continual Backpropagation, 2024)](https://www.nature.com/articles/s41586-024-07711-7) is a continual learning approach that reinitialises a small number of units during training, using an utility measures to determine which units to reinitialise. It aims to address loss of plasticity problem for learning new tasks, yet not very well solve the catastrophic forgetting problem in continual learning. 26 27 We implement CBP as a subclass of Finetuning algorithm, as CBP has the same `forward()`, `training_step()`, `validation_step()` and `test_step()` method as `Finetuning` class. 28 """ 29 30 def __init__( 31 self, 32 backbone: CLBackbone, 33 heads: HeadsTIL | HeadsCIL, 34 replacement_rate: float, 35 maturity_threshold: int, 36 utility_decay_rate: float, 37 ) -> None: 38 r"""Initialise the Finetuning algorithm with the network. It has no additional hyperparamaters. 39 40 **Args:** 41 - **backbone** (`CLBackbone`): backbone network. 42 - **heads** (`HeadsTIL` | `HeadsCIL`): output heads. 43 - **replacement_rate** (`float`): the replacement rate of units. It is the precentage of units to be reinitialised during training. 44 - **maturity_threshold** (`int`): the maturity threshold of units. It is the number of training steps before a unit can be reinitialised. 45 - **utility_decay_rate** (`float`): the utility decay rate of units. It is the rate at which the utility of a unit decays over time. 46 """ 47 Finetuning.__init__(self, backbone=backbone, heads=heads) 48 49 self.replacement_rate: float = replacement_rate 50 r"""Store the replacement rate of units. """ 51 self.maturity_threshold: int = maturity_threshold 52 r"""Store the maturity threshold of units. """ 53 self.utility_decay_rate: float = utility_decay_rate 54 r"""Store the utility decay rate of units. """ 55 56 self.contribution_utility: dict[str, Tensor] = {} 57 r"""Store the contribution utility of units. See equation (1) in the [continual backpropagation paper](https://www.nature.com/articles/s41586-024-07711-7). Keys are layer names and values are the utility tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units). """ 58 self.num_replacements: dict[str, int] = {} 59 r"""Store the number of replacements of units in each layer. Keys are layer names and values are the number of replacements for the layer. """ 60 self.age: dict[str, Tensor] = {} 61 r"""Store the age of units. Keys are layer names and values are the age tensor for the layer. The age tensor is the same size as the feature tensor with size (1, number of units). """ 62 63 def on_train_start(self) -> None: 64 r"""Initialise the utility, number of replacements and age for each layer as zeros.""" 65 66 # initialise the utility, number of replacements and age as zeros at the beginning of first task. 
This should not be called in `__init__()` method as the `self.device` is not available at that time. 67 if self.task_id == 1: 68 for layer_name in self.backbone.weighted_layer_names: 69 layer = self.backbone.get_layer_by_name( 70 layer_name 71 ) # get the layer by its name 72 num_units = layer.weight.shape[0] 73 74 self.contribution_utility[layer_name] = torch.zeros(num_units).to( 75 self.device 76 ) 77 self.num_replacements[layer_name] = 0 78 self.age[layer_name] = torch.zeros(num_units).to(self.device) 79 80 def on_train_batch_end( 81 self, outputs: dict[str, Any], batch: Any, batch_idx: int 82 ) -> None: 83 r"""Update the contribution utility and age of units after each training step, and conduct reinitialisation of units based on utility measures. This is the core of the CBP algorithm. 84 85 **Args:** 86 - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `CLAlgorithm`. 87 - **batch** (`Any`): the training data batch. 88 - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures. 89 """ 90 91 hidden_features = outputs["hidden_features"] 92 93 for layer_name in self.backbone.weighted_layer_names: 94 # layer-wise operation 95 96 layer = self.backbone.get_layer_by_name( 97 layer_name 98 ) # get the layer by its name 99 100 # update age 101 self.age[layer_name] += 1 102 103 # calculate current contribution utility 104 current_contribution_utility = ( 105 torch.mean( 106 torch.abs(hidden_features[layer_name]), 107 dim=0, # average the features over batch samples 108 ) 109 * torch.sum( 110 torch.abs(layer.weight), 111 dim=1, # sum over the output dimension 112 ) 113 ).detach() 114 current_contribution_utility = min_max_normalise( 115 current_contribution_utility 116 ) # normalise the utility to [0,1] to avoid linearly increasing utility 117 118 # update utility 119 self.contribution_utility[layer_name] = ( 120 self.utility_decay_rate * self.contribution_utility[layer_name] 121 + (1 - self.utility_decay_rate) * current_contribution_utility 122 ) 123 124 # find eligible units 125 eligible_mask = self.age[layer_name] > self.maturity_threshold 126 eligible_indices = torch.where(eligible_mask)[0] 127 128 # update the number of replacements 129 num_eligible_units = eligible_indices.numel() 130 self.num_replacements[layer_name] += int( 131 self.replacement_rate * num_eligible_units 132 ) 133 134 # if the number of replacements is greater than 1, execute the replacement 135 if self.num_replacements[layer_name] > 1: 136 137 # find the unit with smallest utility among eligible units 138 replaced_unit_idx = eligible_indices[ 139 torch.argmin( 140 self.contribution_utility[layer_name][eligible_indices] 141 / self.age[layer_name][eligible_indices] 142 ).item() 143 ] 144 145 # reinitialise the input weights of the unit 146 preceding_layer_name = self.backbone.preceding_layer_name(layer_name) 147 if preceding_layer_name is not None: 148 preceding_layer = self.backbone.get_layer_by_name( 149 preceding_layer_name 150 ) 151 with torch.no_grad(): 152 153 preceding_layer.weight[:, replaced_unit_idx] = torch.rand_like( 154 preceding_layer.weight[:, replaced_unit_idx] 155 ) 156 157 # reinitalise the output weights of the unit 158 with torch.no_grad(): 159 layer.weight[replaced_unit_idx] = torch.rand_like( 160 layer.weight[replaced_unit_idx] 161 ) 162 163 # reinitialise utility 164 self.contribution_utility[layer_name][replaced_unit_idx] = 0.0 165 166 # reintialise age 167 
self.age[layer_name][replaced_unit_idx] = 0 168 169 # update the number of replacements 170 self.num_replacements[layer_name] -= 1
`class CBP(Finetuning):`
CBP (Continual Backpropagation) algorithm.

CBP (Continual Backpropagation, 2024) is a continual learning approach that reinitialises a small number of units during training, using a utility measure to decide which units to reinitialise. It addresses the loss-of-plasticity problem when learning new tasks, but does not directly tackle catastrophic forgetting.

We implement CBP as a subclass of the `Finetuning` algorithm, as CBP shares the same `forward()`, `training_step()`, `validation_step()` and `test_step()` methods as the `Finetuning` class.
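Concretely, the utility driving these reinitialisations, as implemented in `on_train_batch_end()` below, is an exponential moving average of each unit's contribution, min-max normalised to $[0, 1]$ across the layer's units:

$$
\hat{u}_i = \operatorname{minmax}\!\left(\Big(\frac{1}{B}\sum_{b=1}^{B}\lvert h_{b,i}\rvert\Big)\sum_{j}\lvert W_{ij}\rvert\right),
\qquad
u_i \leftarrow \eta\, u_i + (1-\eta)\,\hat{u}_i,
$$

where $h_{b,i}$ is the activation of unit $i$ on batch sample $b$, $W_{ij}$ are the unit's weights, $B$ is the batch size, and $\eta$ is `utility_decay_rate`. Among units whose age exceeds `maturity_threshold`, the unit with the smallest age-normalised utility $u_i / a_i$ (with age $a_i$) is reinitialised.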
`def __init__(self, backbone: CLBackbone, heads: HeadsTIL | HeadsCIL, replacement_rate: float, maturity_threshold: int, utility_decay_rate: float) -> None:`
Initialise the CBP algorithm with the network.

Args:
- **backbone** (`CLBackbone`): backbone network.
- **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
- **replacement_rate** (`float`): the replacement rate of units, i.e. the percentage of eligible units to be reinitialised during each training step.
- **maturity_threshold** (`int`): the maturity threshold of units, i.e. the number of training steps before a unit can be reinitialised.
- **utility_decay_rate** (`float`): the rate at which a unit's utility decays over time.
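As a usage sketch, CBP is constructed like any other algorithm in `cl_algorithms`. The `MLP` backbone and the constructor arguments below are assumptions for illustration only, and the hyperparameter values are merely plausible starting points (roughly in the range explored in the CBP paper), not defaults of this class:

```python
from clarena.backbones import MLP  # hypothetical backbone choice for this sketch
from clarena.cl_algorithms import CBP
from clarena.cl_heads import HeadsTIL

# backbone/head constructor arguments are assumed for illustration
backbone = MLP(input_dim=784, hidden_dims=[256, 256], output_dim=64)
heads = HeadsTIL(input_dim=64)

model = CBP(
    backbone=backbone,
    heads=heads,
    replacement_rate=1e-4,    # fraction of eligible units replaced per step
    maturity_threshold=100,   # steps before a unit becomes eligible
    utility_decay_rate=0.99,  # EMA decay for the contribution utility
)
```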
Store the contribution utility of units. See equation (1) in the continual backpropagation paper. Keys are layer names and values are the utility tensor for the layer, of size (number of units), matching the layer's feature tensor.

Store the accumulated number of replacements of units in each layer. Keys are layer names and values are the number of replacements for the layer.

Store the age of units. Keys are layer names and values are the age tensor for the layer, of size (number of units), matching the layer's feature tensor.
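To make this bookkeeping concrete, here is a minimal standalone sketch (toy values, independent of this class) of one step of the age and utility updates for a single layer with four units:

```python
import torch

utility = torch.zeros(4)  # contribution_utility entry for one layer
age = torch.zeros(4)      # age entry for one layer
decay = 0.99              # utility_decay_rate

# one training step: age every unit, then fold the step's (already
# normalised) utilities into the exponential moving average
age += 1
current_utility = torch.tensor([0.10, 0.90, 0.40, 0.05])  # toy values in [0, 1]
utility = decay * utility + (1 - decay) * current_utility
# utility is now approximately [0.0010, 0.0090, 0.0040, 0.0005]
```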
`def on_train_start(self) -> None:`
Initialise the utility, number of replacements and age for each layer as zeros.
`def on_train_batch_end(self, outputs: dict[str, Any], batch: Any, batch_idx: int) -> None:`
Update the contribution utility and age of units after each training step, and reinitialise units based on the utility measure. This is the core of the CBP algorithm.

Args:
- **outputs** (`dict[str, Any]`): the outputs of the training step, i.e. the return value of the `training_step()` method in `CLAlgorithm`.
- **batch** (`Any`): the training data batch.
- **batch_idx** (`int`): the index of the current batch.
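For illustration, a self-contained sketch of the selection-and-reinitialisation step on a toy layer; this is a simplified stand-in for the logic above (the layer, utility and age values are made up), not this class's actual code path:

```python
import torch
from torch import nn

layer = nn.Linear(8, 4)  # a layer with 4 tracked output units
utility = torch.tensor([0.50, 0.02, 0.30, 0.40])  # toy contribution utilities
age = torch.tensor([150.0, 200.0, 50.0, 120.0])   # toy ages in training steps
maturity_threshold = 100

# only sufficiently old units are eligible for replacement
eligible = torch.where(age > maturity_threshold)[0]  # units 0, 1 and 3

# pick the eligible unit with the smallest age-normalised utility
idx = eligible[torch.argmin(utility[eligible] / age[eligible]).item()]
# unit 1: 0.02 / 200 = 1e-4 is the smallest ratio

# reinitialise that unit's weights and reset its statistics
with torch.no_grad():
    layer.weight[idx] = torch.rand_like(layer.weight[idx])
utility[idx] = 0.0
age[idx] = 0.0
```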