clarena.cl_algorithms.adahat
The submodule in `cl_algorithms` for the AdaHAT (Adaptive Hard Attention to the Task) algorithm.
```python
r"""
The submodule in `cl_algorithms` for the [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) algorithm.
"""

__all__ = ["AdaHAT"]

import logging
from typing import Any

import torch
from torch import Tensor

from clarena.backbones import HATMaskBackbone
from clarena.cl_algorithms import HAT
from clarena.heads import HeadsTIL
from clarena.heads.head_dil import HeadDIL
from clarena.utils.metrics import HATNetworkCapacityMetric

# always get the logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class AdaHAT(HAT):
    r"""[AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) algorithm.

    An architecture-based continual learning approach that improves [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.

    We implement AdaHAT as a subclass of `HAT`, as it shares the same `forward()`, `compensate_task_embedding_gradients()`, `training_step()`, `on_train_end()`, `validation_step()`, and `test_step()` methods as the `HAT` class.
    """

    def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL | HeadDIL,
        adjustment_mode: str,
        adjustment_intensity: float,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        task_embedding_init_mode: str = "N01",
        epsilon: float = 0.1,
        non_algorithmic_hparams: dict[str, Any] = {},
        **kwargs,
    ) -> None:
        r"""Initialize the AdaHAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
            1. 'adahat': set gradients of parameters linking to masked units to a soft adjustment rate, as in the original AdaHAT approach (allows slight updates on previous-task parameters). See Eqs. (8) and (9) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            2. 'adahat_no_sum': as above but without parameter importance (i.e., no summative mask). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            3. 'adahat_no_reg': as above but without network sparsity (i.e., no mask sparsity regularization term). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **adjustment_intensity** (`float`): hyperparameter controlling the overall intensity of gradient adjustment (the $\alpha$ in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'cross': the cross version of mask sparsity regularization.
        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit the task embedding from the last task.
        - **epsilon** (`float`): the value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters, unrelated to the algorithm itself, that are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        - **kwargs**: reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            adjustment_mode=adjustment_mode,
            s_max=s_max,
            clamp_threshold=clamp_threshold,
            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
            task_embedding_init_mode=task_embedding_init_mode,
            alpha=None,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        self.adjustment_intensity: float = adjustment_intensity
        r"""The adjustment intensity in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)."""
        self.epsilon: float = epsilon
        r"""The small value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9))."""

        # save additional algorithmic hyperparameters
        self.save_hyperparameters("adjustment_intensity", "epsilon")

        self.summative_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""The summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1, \cdots, t-1$, gated from the task embeddings. It is a dict where keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has size (number of units, )."""

        # set manual optimization
        self.automatic_optimization = False

        AdaHAT.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity-check the hyperparameters."""
        if self.adjustment_intensity <= 0:
            raise ValueError(
                f"The adjustment intensity should be positive, but got {self.adjustment_intensity}."
            )

    def on_train_start(self) -> None:
        r"""Additionally initialize the summative mask at the beginning of the first task."""
        super().on_train_start()

        # initialize the summative mask at the beginning of the first task. This cannot be done in the `__init__()` method because `self.device` is not available at that time
        if self.summative_mask_for_previous_tasks == {}:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name
                num_units = layer.weight.shape[0]

                # the summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialized as a zero mask for $t = 1$. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
                self.summative_mask_for_previous_tasks[layer_name] = torch.zeros(
                    num_units, device=self.device
                )

    def clip_grad_by_adjustment(
        self,
        network_sparsity: dict[str, Tensor] | None = None,
    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
        r"""Clip the gradients by the adjustment rate. See Eq. (8) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        Note that because the task embeddings fully cover every layer in the backbone network, no parameter is left out of this mechanism. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

        Network capacity is measured alongside this method. It is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        **Args:**
        - **network_sparsity** (`dict[str, Tensor]` | `None`): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. Applies only to modes `adahat` and `adahat_no_sum`, not `adahat_no_reg`.

        **Returns:**
        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **capacity** (`Tensor`): the calculated network capacity.
        """
        # initialize the network capacity metric
        capacity = HATNetworkCapacityMetric().to(self.device)
        adjustment_rate_weight = {}
        adjustment_rate_bias = {}

        # calculate the adjustment rate for the gradients of the parameters, both weights and biases (if they exist). See Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
        for layer_name in self.backbone.weighted_layer_names:

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # placeholders for the adjustment rates to avoid use before assignment
            adjustment_rate_weight_layer = 1
            adjustment_rate_bias_layer = 1

            weight_importance, bias_importance = (
                self.backbone.get_layer_measure_parameter_wise(
                    neuron_wise_measure=self.summative_mask_for_previous_tasks,
                    layer_name=layer_name,
                    aggregation_mode="min",
                )
            )  # AdaHAT depends on parameter importance rather than parameter masks (as in HAT)

            # network sparsity is only needed for modes `adahat` and `adahat_no_sum`
            network_sparsity_layer = (
                network_sparsity[layer_name] if network_sparsity is not None else None
            )

            if self.adjustment_mode == "adahat":
                r_layer = self.adjustment_intensity / (
                    self.epsilon + network_sparsity_layer
                )
                adjustment_rate_weight_layer = torch.div(
                    r_layer, (weight_importance + r_layer)
                )
                adjustment_rate_bias_layer = torch.div(
                    r_layer, (bias_importance + r_layer)
                )

            elif self.adjustment_mode == "adahat_no_sum":
                r_layer = self.adjustment_intensity / (
                    self.epsilon + network_sparsity_layer
                )
                adjustment_rate_weight_layer = torch.div(
                    r_layer, (self.task_id + r_layer)
                )
                adjustment_rate_bias_layer = torch.div(
                    r_layer, (self.task_id + r_layer)
                )

            elif self.adjustment_mode == "adahat_no_reg":
                r_layer = self.adjustment_intensity / (self.epsilon + 0.0)
                adjustment_rate_weight_layer = torch.div(
                    r_layer, (weight_importance + r_layer)
                )
                adjustment_rate_bias_layer = torch.div(
                    r_layer, (bias_importance + r_layer)
                )

            # apply the adjustment rates to the gradients
            layer.weight.grad.data *= adjustment_rate_weight_layer
            if layer.bias is not None:
                layer.bias.grad.data *= adjustment_rate_bias_layer

            # store the adjustment rates for logging
            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
            if layer.bias is not None:
                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer

            # update the network capacity metric
            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)

        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()

    def on_train_end(self) -> None:
        r"""Additionally update the summative mask after training the task."""
        super().on_train_end()

        mask_t = self.backbone.masks[
            self.task_id
        ]  # get the stored mask for the current task again

        # update the summative mask for previous tasks. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
        self.summative_mask_for_previous_tasks = {
            layer_name: self.summative_mask_for_previous_tasks[layer_name]
            + mask_t[layer_name]
            for layer_name in self.backbone.weighted_layer_names
        }
```
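The soft gradient clipping in `clip_grad_by_adjustment` can be illustrated with a minimal standalone sketch of Eqs. (8) and (9). The values below are made up for illustration and only `torch` is assumed; `alpha`, `epsilon`, the sparsity value, and the importance vector stand in for `adjustment_intensity`, `self.epsilon`, `network_sparsity[layer_name]`, and the importance returned by `get_layer_measure_parameter_wise`:

```python
import torch

# hypothetical per-layer quantities (illustrative values, not from clarena)
alpha = 1e-4                                # adjustment_intensity, the alpha in Eq. (9)
epsilon = 0.1                               # guard against division by zero
network_sparsity_layer = torch.tensor(0.3)  # mask sparsity loss of this layer

# Eq. (9): the layer-wise base rate r_l = alpha / (epsilon + sparsity)
r_layer = alpha / (epsilon + network_sparsity_layer)

# parameter importance derived from the summative mask of previous tasks
# (the "min" aggregation over the masks of the units each weight connects)
weight_importance = torch.tensor([0.0, 1.0, 3.0])

# Eq. (8): soft clipping rate r_l / (importance + r_l), multiplied onto gradients
adjustment_rate = r_layer / (weight_importance + r_layer)

# a parameter with importance 0 keeps its full gradient (rate 1.0), while
# parameters heavily reused by previous tasks are updated only slightly
```

This makes the "adaptive" part visible: the rate decays smoothly with accumulated importance instead of the hard 0/1 masking of HAT, and a sparser network (smaller sparsity loss) yields a larger `r_layer`, i.e., more freedom to update.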
24class AdaHAT(HAT): 25 r"""[AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) algorithm. 26 27 An architecture-based continual learning approach that improves [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) by introducing adaptive soft gradient clipping based on parameter importance and network sparsity. 28 29 We implement AdaHAT as a subclass of HAT, as it shares the same `forward()`, `compensate_task_embedding_gradients()`, `training_step()`, `on_train_end()`, `validation_step()`, and `test_step()` methods as the `HAT` class. 30 """ 31 32 def __init__( 33 self, 34 backbone: HATMaskBackbone, 35 heads: HeadsTIL | HeadDIL, 36 adjustment_mode: str, 37 adjustment_intensity: float, 38 s_max: float, 39 clamp_threshold: float, 40 mask_sparsity_reg_factor: float, 41 mask_sparsity_reg_mode: str = "original", 42 task_embedding_init_mode: str = "N01", 43 epsilon: float = 0.1, 44 non_algorithmic_hparams: dict[str, Any] = {}, 45 **kwargs, 46 ) -> None: 47 r"""Initialize the AdaHAT algorithm with the network. 48 49 **Args:** 50 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 51 - **heads** (`HeadsTIL` | `HeadDIL`): output heads. AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning). 52 - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of: 53 1. 'adahat': set gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach (allows slight updates on previous-task parameters). See Eqs. (8) and (9) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 54 2. 'adahat_no_sum': as above but without parameter-importance (i.e., no summative mask). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 55 3. 
'adahat_no_reg': as above but without network sparsity (i.e., no mask sparsity regularization term). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 56 - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)). 57 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 58 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 59 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity. 60 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 61 1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 62 2. 'cross': the cross version of mask sparsity regularization. 63 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 64 1. 'N01' (default): standard normal distribution $N(0, 1)$. 65 2. 'U-11': uniform distribution $U(-1, 1)$. 66 3. 'U01': uniform distribution $U(0, 1)$. 67 4. 'U-10': uniform distribution $U(-1, 0)$. 68 5. 'last': inherit the task embedding from the last task. 69 - **epsilon** (`float`): the value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)). 
70 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 71 - **kwargs**: Reserved for multiple inheritance. 72 """ 73 super().__init__( 74 backbone=backbone, 75 heads=heads, 76 adjustment_mode=adjustment_mode, 77 s_max=s_max, 78 clamp_threshold=clamp_threshold, 79 mask_sparsity_reg_factor=mask_sparsity_reg_factor, 80 mask_sparsity_reg_mode=mask_sparsity_reg_mode, 81 task_embedding_init_mode=task_embedding_init_mode, 82 alpha=None, 83 non_algorithmic_hparams=non_algorithmic_hparams, 84 **kwargs, 85 ) 86 87 self.adjustment_intensity: float = adjustment_intensity 88 r"""The adjustment intensity in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).""" 89 self.epsilon: float | None = epsilon 90 r"""The small value to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).""" 91 92 # save additional algorithmic hyperparameters 93 self.save_hyperparameters("adjustment_intensity", "epsilon") 94 95 self.summative_mask_for_previous_tasks: dict[str, Tensor] = {} 96 r"""The summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding. It is a dict where keys are layer names and values are the binary mask tensors for the layers. 
The mask tensor has size (number of units, ).""" 97 98 # set manual optimization 99 self.automatic_optimization = False 100 101 AdaHAT.sanity_check(self) 102 103 def sanity_check(self) -> None: 104 r"""Sanity check.""" 105 if self.adjustment_intensity <= 0: 106 raise ValueError( 107 f"The adjustment intensity should be positive, but got {self.adjustment_intensity}." 108 ) 109 110 def on_train_start(self) -> None: 111 r"""Additionally initialize the summative mask at the beginning of the first task.""" 112 super().on_train_start() 113 114 # initialize the summative mask at the beginning of the first task. This should not be called in `__init__()` method because `self.device` is not available at that time 115 if self.summative_mask_for_previous_tasks == {}: 116 for layer_name in self.backbone.weighted_layer_names: 117 layer = self.backbone.get_layer_by_name( 118 layer_name 119 ) # get the layer by its name 120 num_units = layer.weight.shape[0] 121 122 self.summative_mask_for_previous_tasks[layer_name] = torch.zeros( 123 num_units 124 ).to( 125 self.device 126 ) # the summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialized as a zeros mask for $t = 1$. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) 127 128 def clip_grad_by_adjustment( 129 self, 130 network_sparsity: dict[str, Tensor] | None = None, 131 ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]: 132 r"""Clip the gradients by the adjustment rate. See Eq. (8) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 133 134 Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code. 135 136 Network capacity is measured alongside this method. 
Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 137 138 **Args:** 139 - **network_sparsity** (`dict[str, Tensor]` | `None`): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. Applies only to mode `adahat` and `adahat_no_sum`, not `adahat_no_reg`. 140 141 **Returns:** 142 - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 143 - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 144 - **capacity** (`Tensor`): the calculated network capacity. 145 """ 146 147 # initialize network capacity metric 148 capacity = HATNetworkCapacityMetric().to(self.device) 149 adjustment_rate_weight = {} 150 adjustment_rate_bias = {} 151 152 # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. 
(9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) 153 for layer_name in self.backbone.weighted_layer_names: 154 155 layer = self.backbone.get_layer_by_name( 156 layer_name 157 ) # get the layer by its name 158 159 # placeholder for the adjustment rate to avoid the error of using it before assignment 160 adjustment_rate_weight_layer = 1 161 adjustment_rate_bias_layer = 1 162 163 weight_importance, bias_importance = ( 164 self.backbone.get_layer_measure_parameter_wise( 165 neuron_wise_measure=self.summative_mask_for_previous_tasks, 166 layer_name=layer_name, 167 aggregation_mode="min", 168 ) 169 ) # AdaHAT depends on parameter importance rather than parameter masks (as in HAT) 170 171 network_sparsity_layer = network_sparsity[layer_name] 172 173 if self.adjustment_mode == "adahat": 174 r_layer = self.adjustment_intensity / ( 175 self.epsilon + network_sparsity_layer 176 ) 177 adjustment_rate_weight_layer = torch.div( 178 r_layer, (weight_importance + r_layer) 179 ) 180 adjustment_rate_bias_layer = torch.div( 181 r_layer, (bias_importance + r_layer) 182 ) 183 184 elif self.adjustment_mode == "adahat_no_sum": 185 186 r_layer = self.adjustment_intensity / ( 187 self.epsilon + network_sparsity_layer 188 ) 189 adjustment_rate_weight_layer = torch.div( 190 r_layer, (self.task_id + r_layer) 191 ) 192 adjustment_rate_bias_layer = torch.div( 193 r_layer, (self.task_id + r_layer) 194 ) 195 196 elif self.adjustment_mode == "adahat_no_reg": 197 198 r_layer = self.adjustment_intensity / (self.epsilon + 0.0) 199 adjustment_rate_weight_layer = torch.div( 200 r_layer, (weight_importance + r_layer) 201 ) 202 adjustment_rate_bias_layer = torch.div( 203 r_layer, (bias_importance + r_layer) 204 ) 205 206 # apply the adjustment rate to the gradients 207 layer.weight.grad.data *= adjustment_rate_weight_layer 208 if layer.bias is not None: 209 layer.bias.grad.data *= adjustment_rate_bias_layer 210 211 # store the adjustment rate for logging 212 
adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer 213 if layer.bias is not None: 214 adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer 215 216 # update network capacity metric 217 capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer) 218 219 return adjustment_rate_weight, adjustment_rate_bias, capacity.compute() 220 221 def on_train_end(self) -> None: 222 r"""Additionally update the summative mask after training the task.""" 223 super().on_train_end() 224 225 mask_t = self.backbone.masks[ 226 self.task_id 227 ] # get stored mask for the current task again 228 229 # update the summative mask for previous tasks. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) 230 self.summative_mask_for_previous_tasks = { 231 layer_name: self.summative_mask_for_previous_tasks[layer_name] 232 + mask_t[layer_name] 233 for layer_name in self.backbone.weighted_layer_names 234 }
AdaHAT (Adaptive Hard Attention to the Task) algorithm.
An architecture-based continual learning approach that improves HAT (Hard Attention to the Task) by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
We implement AdaHAT as a subclass of HAT, as it shares the same forward(), compensate_task_embedding_gradients(), training_step(), on_train_end(), validation_step(), and test_step() methods as the HAT class.
32 def __init__( 33 self, 34 backbone: HATMaskBackbone, 35 heads: HeadsTIL | HeadDIL, 36 adjustment_mode: str, 37 adjustment_intensity: float, 38 s_max: float, 39 clamp_threshold: float, 40 mask_sparsity_reg_factor: float, 41 mask_sparsity_reg_mode: str = "original", 42 task_embedding_init_mode: str = "N01", 43 epsilon: float = 0.1, 44 non_algorithmic_hparams: dict[str, Any] = {}, 45 **kwargs, 46 ) -> None: 47 r"""Initialize the AdaHAT algorithm with the network. 48 49 **Args:** 50 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 51 - **heads** (`HeadsTIL` | `HeadDIL`): output heads. AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning). 52 - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of: 53 1. 'adahat': set gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach (allows slight updates on previous-task parameters). See Eqs. (8) and (9) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 54 2. 'adahat_no_sum': as above but without parameter-importance (i.e., no summative mask). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 55 3. 'adahat_no_reg': as above but without network sparsity (i.e., no mask sparsity regularization term). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 56 - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)). 57 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 
58 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 59 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity. 60 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 61 1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 62 2. 'cross': the cross version of mask sparsity regularization. 63 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 64 1. 'N01' (default): standard normal distribution $N(0, 1)$. 65 2. 'U-11': uniform distribution $U(-1, 1)$. 66 3. 'U01': uniform distribution $U(0, 1)$. 67 4. 'U-10': uniform distribution $U(-1, 0)$. 68 5. 'last': inherit the task embedding from the last task. 69 - **epsilon** (`float`): the value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)). 70 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 71 - **kwargs**: Reserved for multiple inheritance. 
72 """ 73 super().__init__( 74 backbone=backbone, 75 heads=heads, 76 adjustment_mode=adjustment_mode, 77 s_max=s_max, 78 clamp_threshold=clamp_threshold, 79 mask_sparsity_reg_factor=mask_sparsity_reg_factor, 80 mask_sparsity_reg_mode=mask_sparsity_reg_mode, 81 task_embedding_init_mode=task_embedding_init_mode, 82 alpha=None, 83 non_algorithmic_hparams=non_algorithmic_hparams, 84 **kwargs, 85 ) 86 87 self.adjustment_intensity: float = adjustment_intensity 88 r"""The adjustment intensity in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).""" 89 self.epsilon: float | None = epsilon 90 r"""The small value to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).""" 91 92 # save additional algorithmic hyperparameters 93 self.save_hyperparameters("adjustment_intensity", "epsilon") 94 95 self.summative_mask_for_previous_tasks: dict[str, Tensor] = {} 96 r"""The summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ).""" 97 98 # set manual optimization 99 self.automatic_optimization = False 100 101 AdaHAT.sanity_check(self)
Initialize the AdaHAT algorithm with the network.
Args:
- backbone (
HATMaskBackbone): must be a backbone network with the HAT mask mechanism. - heads (
HeadsTIL|HeadDIL): output heads. AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning). - adjustment_mode (
str): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:- 'adahat': set gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach (allows slight updates on previous-task parameters). See Eqs. (8) and (9) in Sec. 3.1 of the AdaHAT paper.
- 'adahat_no_sum': as above but without parameter-importance (i.e., no summative mask). See Sec. 4.3 (ablation study) in the AdaHAT paper.
- 'adahat_no_reg': as above but without network sparsity (i.e., no mask sparsity regularization term). See Sec. 4.3 (ablation study) in the AdaHAT paper.
- adjustment_intensity (
float): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in Eq. (9) of the AdaHAT paper). - s_max (
float): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper. - clamp_threshold (
float): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper. - mask_sparsity_reg_factor (
float): hyperparameter, the regularization factor for mask sparsity. - mask_sparsity_reg_mode (
str): the mode of mask sparsity regularization, must be one of:- 'original' (default): the original mask sparsity regularization in the HAT paper.
- 'cross': the cross version of mask sparsity regularization.
- task_embedding_init_mode (
str): the initialization mode for task embeddings, must be one of:- 'N01' (default): standard normal distribution $N(0, 1)$.
- 'U-11': uniform distribution $U(-1, 1)$.
- 'U01': uniform distribution $U(0, 1)$.
- 'U-10': uniform distribution $U(-1, 0)$.
- 'last': inherit the task embedding from the last task.
- epsilon (
float): the value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the AdaHAT paper). - non_algorithmic_hparams (
dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to thisLightningModuleobject from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs fromsave_hyperparameters()method. This is useful for the experiment configuration and reproducibility. - kwargs: Reserved for multiple inheritance.
The small value to avoid division by zero (appearing in Eq. (9) of the AdaHAT paper).
The summative binary attention mask $\mathrm{M}^{&lt;t,\text{sum}}$ accumulated over previous tasks. Keys are layer names and values are the mask tensors. See Eq. (7) in Sec. 3.1 of the AdaHAT paper.
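The role of `s_max` above can be illustrated with a minimal, pure-Python sketch of the annealed sigmoid gate described in Sec. 2.4 of the HAT paper; the function names and numbers here are illustrative, not part of the `clarena` API:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def annealed_scale(batch_idx: int, num_batches: int, s_max: float) -> float:
    # the scaling factor s is annealed from 1/s_max to s_max over an epoch,
    # so the gate is soft (trainable) early and near-binary by the end
    return 1.0 / s_max + (s_max - 1.0 / s_max) * batch_idx / (num_batches - 1)

e = 2.5  # one entry of a task embedding
soft = sigmoid(annealed_scale(0, 100, s_max=400.0) * e)   # first batch: near 0.5, gradients flow
hard = sigmoid(annealed_scale(99, 100, s_max=400.0) * e)  # last batch: near-binary mask
```

A larger `s_max` pushes the trained masks closer to hard binary gates at the end of each epoch.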
@property
def automatic_optimization(self) -> bool:
    """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
    return self._automatic_optimization
If set to `False` you are responsible for calling `.backward()`, `.step()`, `.zero_grad()`.
def sanity_check(self) -> None:
    r"""Sanity check."""
    if self.adjustment_intensity <= 0:
        raise ValueError(
            f"The adjustment intensity should be positive, but got {self.adjustment_intensity}."
        )
Sanity check.
def on_train_start(self) -> None:
    r"""Additionally initialize the summative mask at the beginning of the first task."""
    super().on_train_start()

    # initialize the summative mask at the beginning of the first task. This should not be done in the `__init__()` method because `self.device` is not available at that time
    if self.summative_mask_for_previous_tasks == {}:
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name
            num_units = layer.weight.shape[0]

            self.summative_mask_for_previous_tasks[layer_name] = torch.zeros(
                num_units
            ).to(
                self.device
            )  # the summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialized as a zeros mask for $t = 1$. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
Additionally initialize the summative mask at the beginning of the first task.
def clip_grad_by_adjustment(
    self,
    network_sparsity: dict[str, Tensor] | None = None,
) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
    r"""Clip the gradients by the adjustment rate. See Eq. (8) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

    Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

    Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

    **Args:**
    - **network_sparsity** (`dict[str, Tensor]` | `None`): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. Applies only to modes `adahat` and `adahat_no_sum`, not `adahat_no_reg`.

    **Returns:**
    - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
    - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
    - **capacity** (`Tensor`): the calculated network capacity.
    """

    # initialize network capacity metric
    capacity = HATNetworkCapacityMetric().to(self.device)
    adjustment_rate_weight = {}
    adjustment_rate_bias = {}

    # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
    for layer_name in self.backbone.weighted_layer_names:

        layer = self.backbone.get_layer_by_name(
            layer_name
        )  # get the layer by its name

        # placeholder for the adjustment rate to avoid the error of using it before assignment
        adjustment_rate_weight_layer = 1
        adjustment_rate_bias_layer = 1

        weight_importance, bias_importance = (
            self.backbone.get_layer_measure_parameter_wise(
                neuron_wise_measure=self.summative_mask_for_previous_tasks,
                layer_name=layer_name,
                aggregation_mode="min",
            )
        )  # AdaHAT depends on parameter importance rather than parameter masks (as in HAT)

        network_sparsity_layer = network_sparsity[layer_name]

        if self.adjustment_mode == "adahat":
            r_layer = self.adjustment_intensity / (
                self.epsilon + network_sparsity_layer
            )
            adjustment_rate_weight_layer = torch.div(
                r_layer, (weight_importance + r_layer)
            )
            adjustment_rate_bias_layer = torch.div(
                r_layer, (bias_importance + r_layer)
            )

        elif self.adjustment_mode == "adahat_no_sum":
            r_layer = self.adjustment_intensity / (
                self.epsilon + network_sparsity_layer
            )
            adjustment_rate_weight_layer = torch.div(
                r_layer, (self.task_id + r_layer)
            )
            adjustment_rate_bias_layer = torch.div(
                r_layer, (self.task_id + r_layer)
            )

        elif self.adjustment_mode == "adahat_no_reg":
            r_layer = self.adjustment_intensity / (self.epsilon + 0.0)
            adjustment_rate_weight_layer = torch.div(
                r_layer, (weight_importance + r_layer)
            )
            adjustment_rate_bias_layer = torch.div(
                r_layer, (bias_importance + r_layer)
            )

        # apply the adjustment rate to the gradients
        layer.weight.grad.data *= adjustment_rate_weight_layer
        if layer.bias is not None:
            layer.bias.grad.data *= adjustment_rate_bias_layer

        # store the adjustment rate for logging
        adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
        if layer.bias is not None:
            adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer

        # update network capacity metric
        capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)

    return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
Clip the gradients by the adjustment rate. See Eq. (8) in the AdaHAT paper.
Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.
**Args:**
- **network_sparsity** (`dict[str, Tensor]` | `None`): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. Applies only to modes `adahat` and `adahat_no_sum`, not `adahat_no_reg`.

**Returns:**
- **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
- **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
- **capacity** (`Tensor`): the calculated network capacity.
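The adjustment of Eqs. (8)-(9) can be traced with a small, pure-Python sketch using toy scalar values (the actual method operates on per-layer tensors; the variable names here are illustrative):

```python
# Eq. (9): r_l = alpha / (epsilon + network_sparsity_l), then each
# parameter's gradient is scaled by r_l / (importance + r_l).
alpha = 1e-3             # adjustment_intensity
epsilon = 0.1            # avoids division by zero
network_sparsity = 0.05  # mask sparsity loss of this layer

r = alpha / (epsilon + network_sparsity)

def adjustment_rate(importance: float) -> float:
    # importance is the summative-mask-based parameter importance
    return r / (importance + r)

free = adjustment_rate(0.0)  # unit unused by previous tasks: rate is exactly 1.0
busy = adjustment_rate(5.0)  # unit reused by many tasks: rate is close to 0
```

Gradients on free parameters pass through unchanged, while gradients on heavily reused parameters are softly (never completely) suppressed, which is the key difference from HAT's hard binary clipping.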
def on_train_end(self) -> None:
    r"""Additionally update the summative mask after training the task."""
    super().on_train_end()

    mask_t = self.backbone.masks[
        self.task_id
    ]  # get stored mask for the current task again

    # update the summative mask for previous tasks. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
    self.summative_mask_for_previous_tasks = {
        layer_name: self.summative_mask_for_previous_tasks[layer_name]
        + mask_t[layer_name]
        for layer_name in self.backbone.weighted_layer_names
    }
Additionally update the summative mask after training the task.
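The bookkeeping of Eq. (7) can be sketched in pure Python with toy 3-unit layers (lists stand in for the per-layer tensors; the names are illustrative):

```python
# start from zeros (as in on_train_start), then element-wise add each
# task's binary mask after training it (as in on_train_end)
layer_names = ["fc1", "fc2"]
summative = {name: [0.0, 0.0, 0.0] for name in layer_names}

task_masks = [
    {"fc1": [1, 0, 1], "fc2": [0, 1, 0]},  # binary mask gated for task 1
    {"fc1": [1, 1, 0], "fc2": [0, 1, 1]},  # binary mask gated for task 2
]
for mask_t in task_masks:
    summative = {
        name: [s + m for s, m in zip(summative[name], mask_t[name])]
        for name in layer_names
    }
# summative["fc1"] is now [2, 1, 1]: its first unit served both tasks
```

Each entry of the summative mask thus counts how many previous tasks used that unit, which is exactly the per-unit importance that `clip_grad_by_adjustment()` turns into a soft gradient-scaling factor.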
Inherited Members
- clarena.cl_algorithms.hat.HAT
- adjustment_mode
- s_max
- clamp_threshold
- mask_sparsity_reg_factor
- mask_sparsity_reg_mode
- mark_sparsity_reg
- task_embedding_init_mode
- alpha
- cumulative_mask_for_previous_tasks
- clip_grad_by_mask
- compensate_task_embedding_gradients
- forward
- training_step
- validation_step
- test_step