clarena.cl_algorithms.adahat
The submodule in `cl_algorithms` for the [AdaHAT (Adaptive Hard Attention to the Task) algorithm](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
1r""" 2The submodule in `cl_algorithms` for [AdaHAT (Adaptive Hard Attention to the Task) algorithm](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 3""" 4 5__all__ = ["AdaHAT"] 6 7import logging 8 9import torch 10from torch import Tensor 11 12from clarena.backbones import HATMaskBackbone 13from clarena.cl_algorithms import HAT 14from clarena.cl_heads import HeadsCIL, HeadsTIL 15from clarena.utils import HATNetworkCapacity 16 17# always get logger for built-in logging in each module 18pylogger = logging.getLogger(__name__) 19 20 21class AdaHAT(HAT): 22 r"""AdaHAT (Adaptive Hard Attention to the Task) algorithm. 23 24 [Adaptive HAT (Adaptive Hard Attention to the Task, 2024)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) by introducing new adaptive soft gradient clipping based on parameter importance and network sparsity. 25 26 We implement AdaHAT as a subclass of HAT algorithm, as AdaHAT has the same `forward()`, `compensate_task_embedding_gradients()`, `training_step()`, `on_train_end()`,`validation_step()`, `test_step()` method as `HAT` class. 27 """ 28 29 def __init__( 30 self, 31 backbone: HATMaskBackbone, 32 heads: HeadsTIL | HeadsCIL, 33 adjustment_mode: str, 34 adjustment_intensity: float, 35 s_max: float, 36 clamp_threshold: float, 37 mask_sparsity_reg_factor: float, 38 mask_sparsity_reg_mode: str = "original", 39 task_embedding_init_mode: str = "N01", 40 epsilon: float = 0.1, 41 ) -> None: 42 r"""Initialise the AdaHAT algorithm with the network. 43 44 **Args:** 45 - **backbone** (`HATMaskBackbone`): must be a backbone network with HAT mask mechanism. 46 - **heads** (`HeadsTIL` | `HeadsCIL`): output heads. 47 - **adjustment_mode** (`str`): the strategy of adjustment i.e. the mode of gradient clipping, should be one of the following: 48 1. 'adahat': set the gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach. This is the way that AdaHAT does, which allowes the part of network for previous tasks to be updated slightly. See equation (8) and (9) chapter 3.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 49 2. 'adahat_no_sum': set the gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach, but without considering the information of parameter importance i.e. summative mask. This is the way that one of the AdaHAT ablation study does. See chapter 4.3 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 50 3. 'adahat_no_reg': set the gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach, but without considering the information of network sparsity i.e. mask sparsity regularisation value. This is the way that one of the AdaHAT ablation study does. See chapter 4.3 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 51 - **adjustment_intensity** (`float`): hyperparameter, control the overall intensity of gradient adjustment. It's the $\alpha$ in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 52 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a). 
53 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a). 54 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularisation factor for mask sparsity. 55 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularisation, should be one of the following: 56 1. 'original' (default): the original mask sparsity regularisation in HAT paper. 57 2. 'cross': the cross version mask sparsity regularisation. 58 - **task_embedding_init_mode** (`str`): the initialisation method for task embeddings, should be one of the following: 59 1. 'N01' (default): standard normal distribution $N(0, 1)$. 60 2. 'U-11': uniform distribution $U(-1, 1)$. 61 3. 'U01': uniform distribution $U(0, 1)$. 62 4. 'U-10': uniform distribution $U(-1, 0)$. 63 5. 'last': inherit task embedding from last task. 64 - **epsilon** (`float`): the value added to network sparsity to avoid division by zero appeared in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 65 """ 66 HAT.__init__( 67 self, 68 backbone=backbone, 69 heads=heads, 70 adjustment_mode=adjustment_mode, 71 s_max=s_max, 72 clamp_threshold=clamp_threshold, 73 mask_sparsity_reg_factor=mask_sparsity_reg_factor, 74 mask_sparsity_reg_mode=mask_sparsity_reg_mode, 75 task_embedding_init_mode=task_embedding_init_mode, 76 alpha=None, 77 ) 78 79 self.adjustment_intensity = adjustment_intensity 80 r"""Store the adjustment intensity in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).""" 81 self.epsilon = epsilon 82 """Store the small value to avoid division by zero appeared in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).""" 83 84 self.summative_mask_for_previous_tasks: dict[str, Tensor] = {} 85 r"""Store the summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ previous tasks $1,\cdots, t-1$, gated from the task embedding. Keys are task IDs and values are the corresponding summative mask. Each cumulative mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units). """ 86 87 # set manual optimisation 88 self.automatic_optimization = False 89 90 AdaHAT.sanity_check(self) 91 92 def sanity_check(self) -> None: 93 r"""Check the sanity of the arguments. 94 95 **Raises:** 96 - **ValueError**: If the `adjustment_intensity` is not positive. 97 """ 98 if self.adjustment_intensity <= 0: 99 raise ValueError( 100 f"The adjustment intensity should be positive, but got {self.adjustment_intensity}." 101 ) 102 103 def on_train_start(self) -> None: 104 r"""Additionally initialise the summative mask at the beginning of first task.""" 105 HAT.on_train_start(self) 106 107 # initialise the summative mask at the beginning of first task. This should not be called in `__init__()` method as the `self.device` is not available at that time. 108 if self.task_id == 1: 109 for layer_name in self.backbone.weighted_layer_names: 110 layer = self.backbone.get_layer_by_name( 111 layer_name 112 ) # get the layer by its name 113 num_units = layer.weight.shape[0] 114 115 self.summative_mask_for_previous_tasks[layer_name] = torch.zeros( 116 num_units 117 ).to( 118 self.device 119 ) # the summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialised as zeros mask ($t = 1$). 
See equation (7) in chapter 3.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 120 121 def clip_grad_by_adjustment( 122 self, 123 network_sparsity: dict[str, Tensor] | None = None, 124 ) -> Tensor: 125 r"""Clip the gradients by the adjustment rate. 126 127 Note that as the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only the parameters in between layers with task embedding, but also those before the first layer. We designed it seperately in the codes. 128 129 Network capacity is measured along with this method. Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 130 131 **Args:** 132 - **network_sparsity** (`dict[str, Tensor]` | `None`): The network sparsity i.e. the mask sparsity loss of each layer for the current task. It applies only to AdaHAT modes, as it is used to calculate the adjustment rate for the gradients. 133 134 **Returns:** 135 - **capacity** (`Tensor`): the calculated network capacity. 136 """ 137 138 # initialise network capacity metric 139 capacity = HATNetworkCapacity() 140 141 # Calculate the adjustment rate for gradients of the parameters, both weights and biases (if exists) 142 for layer_name in self.backbone.weighted_layer_names: 143 144 layer = self.backbone.get_layer_by_name( 145 layer_name 146 ) # get the layer by its name 147 148 # placeholder for the adjustment rate to avoid the error of using it before assignment 149 adjustment_rate_weight = 1 150 adjustment_rate_bias = 1 151 152 weight_importance, bias_importance = ( 153 self.backbone.get_layer_measure_parameter_wise( 154 unit_wise_measure=self.summative_mask_for_previous_tasks, 155 layer_name=layer_name, 156 aggregation="min", 157 ) 158 ) # AdaHAT depend on parameter importance instead of parameter mask like HAT 159 160 network_sparsity_layer = network_sparsity[layer_name] 161 162 if self.adjustment_mode == "adahat": 163 r_layer = self.adjustment_intensity / ( 164 self.epsilon + network_sparsity_layer 165 ) 166 adjustment_rate_weight = torch.div( 167 r_layer, (weight_importance + r_layer) 168 ) 169 adjustment_rate_bias = torch.div(r_layer, (bias_importance + r_layer)) 170 171 elif self.adjustment_mode == "adahat_no_sum": 172 173 r_layer = self.adjustment_intensity / ( 174 self.epsilon + network_sparsity_layer 175 ) 176 adjustment_rate_weight = torch.div(r_layer, (self.task_id + r_layer)) 177 adjustment_rate_bias = torch.div(r_layer, (self.task_id + r_layer)) 178 179 elif self.adjustment_mode == "adahat_no_reg": 180 181 r_layer = self.adjustment_intensity / (self.epsilon + 0.0) 182 adjustment_rate_weight = torch.div( 183 r_layer, (weight_importance + r_layer) 184 ) 185 adjustment_rate_bias = torch.div(r_layer, (bias_importance + r_layer)) 186 187 # apply the adjustment rate to the gradients 188 layer.weight.grad.data *= adjustment_rate_weight 189 if layer.bias is not None: 190 layer.bias.grad.data *= adjustment_rate_bias 191 192 # update network capacity metric 193 capacity.update(adjustment_rate_weight, adjustment_rate_bias) 194 195 return capacity.compute() 196 197 def on_train_end(self) -> None: 198 r"""Additionally update summative mask after training the task.""" 199 200 HAT.on_train_end(self) 201 202 mask_t = self.masks[ 203 f"{self.task_id}" 204 ] # get stored mask for the current task again 205 self.summative_mask_for_previous_tasks = { 206 layer_name: 
self.summative_mask_for_previous_tasks[layer_name] 207 + mask_t[layer_name] 208 for layer_name in self.backbone.weighted_layer_names 209 }
`class AdaHAT(HAT)`

AdaHAT (Adaptive Hard Attention to the Task) algorithm.

[Adaptive HAT (Adaptive Hard Attention to the Task, 2024)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.

We implement AdaHAT as a subclass of the HAT algorithm, as AdaHAT shares the `forward()`, `compensate_task_embedding_gradients()`, `training_step()`, `validation_step()` and `test_step()` methods with the `HAT` class.
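To make the adaptive clipping concrete, the per-layer adjustment applied in `clip_grad_by_adjustment()` can be restated as follows; the notation below is ours and simply mirrors the variable names in the code (see equations (8) and (9) of the AdaHAT paper for the original formulation):

$$
r_l^t = \frac{\alpha}{R_l^t + \epsilon}, \qquad
a_{l,ij}^t = \frac{r_l^t}{m_{l,ij}^{<t,\text{sum}} + r_l^t}, \qquad
\nabla w_{l,ij} \leftarrow a_{l,ij}^t \cdot \nabla w_{l,ij},
$$

where $\alpha$ is `adjustment_intensity`, $R_l^t$ is the mask sparsity value of layer $l$ for the current task, $\epsilon$ is `epsilon`, and $m_{l,ij}^{<t,\text{sum}}$ is the parameter-wise importance obtained by min-aggregating the summative masks of the units that $w_{l,ij}$ connects (`get_layer_measure_parameter_wise()` with `aggregation="min"`). In the `'adahat_no_sum'` ablation the importance term is replaced by the number of tasks learned so far, and in `'adahat_no_reg'` the sparsity term $R_l^t$ is dropped.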
```python
def __init__(
    self,
    backbone: HATMaskBackbone,
    heads: HeadsTIL | HeadsCIL,
    adjustment_mode: str,
    adjustment_intensity: float,
    s_max: float,
    clamp_threshold: float,
    mask_sparsity_reg_factor: float,
    mask_sparsity_reg_mode: str = "original",
    task_embedding_init_mode: str = "N01",
    epsilon: float = 0.1,
) -> None:
```
Initialise the AdaHAT algorithm with the network.

**Args:**
- **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
- **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
- **adjustment_mode** (`str`): the strategy of adjustment, i.e. the mode of gradient clipping; should be one of the following:
    1. 'adahat': set the gradients of parameters linking to masked units to a soft adjustment rate, as in the original AdaHAT approach. This allows the part of the network allocated to previous tasks to be updated slightly. See equations (8) and (9) in chapter 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    2. 'adahat_no_sum': the soft adjustment rate of the original AdaHAT approach, but without the parameter importance information, i.e. the summative mask. This is one of the AdaHAT ablation studies. See chapter 4.3 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    3. 'adahat_no_reg': the soft adjustment rate of the original AdaHAT approach, but without the network sparsity information, i.e. the mask sparsity regularisation value. This is one of the AdaHAT ablation studies. See chapter 4.3 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
- **adjustment_intensity** (`float`): hyperparameter controlling the overall intensity of gradient adjustment. It is the $\alpha$ in equation (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
- **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
- **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See chapter 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
- **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularisation factor for mask sparsity.
- **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularisation, should be one of the following:
    1. 'original' (default): the original mask sparsity regularisation in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularisation.
- **task_embedding_init_mode** (`str`): the initialisation method for task embeddings, should be one of the following:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
- **epsilon** (`float`): the value added to the network sparsity to avoid division by zero in equation (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
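A minimal construction sketch under stated assumptions: `backbone` and `heads` stand for an already-built `HATMaskBackbone` and `HeadsTIL` (their own constructor arguments are not shown here), and the hyperparameter values are illustrative rather than recommended settings.

```python
from clarena.cl_algorithms import AdaHAT

# `backbone` (HATMaskBackbone) and `heads` (HeadsTIL or HeadsCIL) are assumed to be
# constructed elsewhere; the values below are illustrative only.
adahat = AdaHAT(
    backbone=backbone,
    heads=heads,
    adjustment_mode="adahat",        # original AdaHAT soft gradient clipping
    adjustment_intensity=0.1,        # alpha in equation (9)
    s_max=400.0,                     # maximum scaling factor of the gate function
    clamp_threshold=50.0,            # embedding gradient compensation threshold
    mask_sparsity_reg_factor=0.75,   # regularisation factor for mask sparsity
    mask_sparsity_reg_mode="original",
    task_embedding_init_mode="N01",
    epsilon=0.1,
)
```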
`summative_mask_for_previous_tasks: dict[str, Tensor]`

Store the summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1, \cdots, t-1$, gated from the task embeddings. Keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has size (number of units).
`automatic_optimization: bool` (inherited property)

If set to `False` you are responsible for calling `.backward()`, `.step()`, `.zero_grad()`. AdaHAT sets this to `False` in `__init__()`, as HAT-style training uses manual optimisation.
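For readers unfamiliar with manual optimisation in PyTorch Lightning, the generic pattern it implies looks roughly like the sketch below. This is not the actual `training_step()` of HAT/AdaHAT (which additionally handles masks, gradient clipping and embedding gradient compensation); it only illustrates the responsibilities that come with `automatic_optimization = False`. The import path assumes Lightning 2.x.

```python
import torch
from lightning.pytorch import LightningModule


class ManualOptimisationSketch(LightningModule):
    """Toy module showing the manual-optimisation pattern, not AdaHAT itself."""

    def __init__(self) -> None:
        super().__init__()
        self.automatic_optimization = False  # we drive the optimiser ourselves
        self.layer = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.manual_backward(loss)  # replaces the automatic loss.backward()
        # ... this is where HAT/AdaHAT adjust gradients before stepping ...
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```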
`sanity_check() -> None`

Check the sanity of the arguments.

**Raises:**
- **ValueError**: if the `adjustment_intensity` is not positive.
`on_train_start() -> None`

Additionally initialise the summative mask at the beginning of the first task. The summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialised as a zero mask for every weighted layer ($t = 1$); see equation (7) in chapter 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). This cannot be done in `__init__()` because `self.device` is not available there.
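A small sketch of what this initialisation produces, for a hypothetical backbone with two weighted layers of 4 and 3 output units (layer names are made up for illustration):

```python
import torch

# Hypothetical layer names and sizes; the real keys come from
# `backbone.weighted_layer_names` and the layers' output dimensions.
summative_mask_for_previous_tasks = {
    "fc1": torch.zeros(4),  # no unit has been used by any previous task yet
    "fc2": torch.zeros(3),
}
```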
`clip_grad_by_adjustment(network_sparsity: dict[str, Tensor] | None = None) -> Tensor`

Clip the gradients by the adjustment rate.

Note that because the task embeddings fully cover every layer in the backbone network, no parameter is left out of this mechanism. This applies not only to the parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

Network capacity is measured along with this method. It is defined as the average adjustment rate over all parameters. See chapter 4.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

**Args:**
- **network_sparsity** (`dict[str, Tensor]` | `None`): the network sparsity, i.e. the mask sparsity loss of each layer for the current task. It applies only to the AdaHAT modes, where it is used to calculate the adjustment rate for the gradients.

**Returns:**
- **capacity** (`Tensor`): the calculated network capacity.
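A standalone numerical sketch of the `'adahat'` branch for a single toy layer; the tensor values and layer shape are invented purely for illustration:

```python
import torch

adjustment_intensity = 0.1  # alpha in equation (9)
epsilon = 0.1
network_sparsity_layer = torch.tensor(0.3)  # mask sparsity loss of this layer

# parameter-wise importance, i.e. the min-aggregated summative mask of previous tasks
weight_importance = torch.tensor([[0.0, 2.0], [5.0, 1.0]])

r_layer = adjustment_intensity / (epsilon + network_sparsity_layer)  # -> 0.25
adjustment_rate_weight = r_layer / (weight_importance + r_layer)
# -> tensor([[1.0000, 0.1111],
#            [0.0476, 0.2000]])
# Parameters never used by previous tasks (importance 0) keep their full gradient,
# while heavily reused parameters (importance 5) are only updated slightly.
```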
`on_train_end() -> None`

Additionally update the summative mask after training the task: the stored binary mask of the current task is added, layer by layer, onto $\mathrm{M}^{<t,\text{sum}}$.
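A hedged sketch of this bookkeeping with made-up layer names and mask values; after each task, the summative mask counts how many previous tasks have used each unit:

```python
import torch

summative_mask = {"fc1": torch.tensor([0.0, 1.0, 2.0])}  # M^{<t, sum} after earlier tasks
mask_t = {"fc1": torch.tensor([1.0, 0.0, 1.0])}          # binary mask of the task just trained

summative_mask = {
    layer_name: summative_mask[layer_name] + mask_t[layer_name]
    for layer_name in summative_mask
}
# -> {"fc1": tensor([1., 1., 3.])}
```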