clarena.cl_algorithms.regularisers.hat_mask_sparsity
The submodule in `regularisers` for HAT mask sparsity regularisation.
1r"""The submodule in `regularisers` for HAT mask sparsity regularisation.""" 2 3__all__ = ["HATMaskSparsityReg"] 4 5from torch import Tensor, nn 6 7 8class HATMaskSparsityReg(nn.Module): 9 r"""Mask sparsity regulariser of HAT (Hard Attention to the Task). 10 11 $$ 12 R\left(\textsf{M}^t,\textsf{M}^{<t}\right)=\text{factor} * \frac{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}\left(1-m_{l,i}^{<t}\right)} 13 $$ 14 15 It promotes the low capacity usage that is reflected by occupation of masks in the parameter space. 16 17 See chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a). 18 """ 19 20 def __init__( 21 self, 22 factor: float, 23 mode: str = "original", 24 ) -> None: 25 r"""Initialise the regulariser. 26 27 **Args:** 28 - **factor** (`float`): the regularisation factor. 29 - **mode** (`str`): the mode of mask sparsity regularisation, should be one of the following: 30 1. 'original' (default): the original mask sparsity regularisation in HAT paper. 31 2. 'cross': the cross version mask sparsity regularisation. 32 """ 33 nn.Module.__init__(self) 34 35 self.factor = factor 36 """Store the regularisation factor for mask sparsity.""" 37 self.mode = mode 38 """Store the mode of mask sparsity regularisation.""" 39 40 def forward( 41 self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor] 42 ) -> Tensor: 43 r"""Calculate the mask sparsity regularisation loss. 44 45 **Args:** 46 - **mask** (`dict[str, Tensor]`): the mask for the current task. 47 - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. 48 49 **Returns:** 50 - **reg** (`Tensor`): the mask sparsity regularisation value. 51 """ 52 53 if self.mode == "original": 54 return self.original_reg(mask, previous_cumulative_mask) 55 56 if self.mode == "cross": 57 return self.cross_reg(mask, previous_cumulative_mask) 58 59 def original_reg( 60 self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor] 61 ) -> Tensor: 62 r"""Calculate the original mask sparsity regularisation loss in HAT paper. 63 64 See chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a). 65 66 **Args:** 67 - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper. 68 - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper. 69 70 **Returns:** 71 - **reg** (`Tensor`): the original mask sparsity regularisation loss. 
72 """ 73 74 count_available = 0 # number of units available for the new task 75 count_new_task_occupation_in_available = ( 76 0 # number of units occupied by the new task in the available units 77 ) 78 79 network_sparsity = {} 80 81 for layer_name in mask.keys(): 82 # statistics through all layers 83 available = ( 84 1 - previous_cumulative_mask[layer_name] 85 ).sum() # count the number of units available for the new task 86 87 new_task_occupation_in_available = ( 88 mask[layer_name] * (1 - previous_cumulative_mask[layer_name]) 89 ).sum() 90 # count the number of units occupied by the new task in the available units 91 92 # add to statistics 93 count_available += available 94 count_new_task_occupation_in_available += new_task_occupation_in_available 95 network_sparsity[layer_name] = ( 96 (new_task_occupation_in_available / available) if available > 10 else 0 97 ) 98 99 # the mask sparsity regularisation minimises the ratio of the number of units occupied by the new task to the number of units available for the new task. The regularisizer is to let HAT allocates more units from previous tasks to the new task rather than using available units. 100 reg = ( 101 count_new_task_occupation_in_available / count_available 102 if count_available 103 > 10 # to avoid division by a very small number, which makes the regularisation less meaningful 104 else 0 105 ) 106 107 return self.factor * reg, network_sparsity 108 109 def cross_reg( 110 self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor] 111 ) -> Tensor: 112 r"""Calculate the cross mask sparsity regularisation loss. This is an attempting improvement by me to the original regularisation, which not only considers the sparsity in available units but also the density in the units occupied by previous tasks. 113 114 **Args:** 115 - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper. 116 - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper. 117 118 **Returns:** 119 - **reg** (`Tensor`): the cross mask sparsity regularisation loss. 120 """ 121 122 count_previous = 0 # number of units occupied by the previous tasks 123 count_new_task_occupation_in_previous = ( 124 0 # number of units occupied by the new task in the previous tasks 125 ) 126 127 network_sparsity_2 = {} 128 129 for layer_name in mask.keys(): 130 # statistics through all layers 131 previous = previous_cumulative_mask[ 132 layer_name 133 ].sum() # count the number of units occupied by the previous tasks 134 new_task_occupation_in_previous = ( 135 mask[layer_name] * previous_cumulative_mask[layer_name].sum() 136 ) # count the number of units occupied by the new task in the previous tasks 137 138 # add to statistics 139 count_previous += previous 140 count_new_task_occupation_in_previous += new_task_occupation_in_previous 141 network_sparsity_2[layer_name] = ( 142 (new_task_occupation_in_previous / previous) if previous > 10 else 0 143 ) 144 145 # the mask sparsity regularisation maximises the ratio of the number of units occupied by the new task to the number of units occupied by the previous tasks. The regularisizer is to let HAT allocates more units from previous tasks to the new task rather than using available units. 
146 reg2 = ( 147 1 - count_new_task_occupation_in_previous / count_previous 148 if count_previous 149 > 10 # to avoid division by a very small number, which makes the regularisation less meaningful 150 else 0 151 ) 152 153 reg1, network_sparsity_1 = self.original_reg(mask, previous_cumulative_mask) 154 155 reg = ( 156 reg1 + reg2 157 ) / 2 # our cross regularisation is the average of the original and the regularisation proposed above 158 159 network_sparsity = {} 160 for layer_name in mask.keys(): 161 # merge the two network sparsity statistics 162 network_sparsity[layer_name] = ( 163 network_sparsity_1[layer_name] + network_sparsity_2[layer_name] 164 ) / 2 165 166 return self.factor * reg, network_sparsity
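A minimal usage sketch of the class above (assuming the `clarena` package is importable; the layer names, mask sizes, and the factor value 0.75 are illustrative assumptions, and the binary masks stand in for HAT's gate values in $[0, 1]$):

```python
import torch

from clarena.cl_algorithms.regularisers.hat_mask_sparsity import HATMaskSparsityReg

# Illustrative per-layer masks: 1 = unit used, 0 = unit free (layer names are made up).
mask = {
    "fc1": (torch.rand(128) > 0.5).float(),
    "fc2": (torch.rand(64) > 0.5).float(),
}
previous_cumulative_mask = {
    "fc1": (torch.rand(128) > 0.7).float(),
    "fc2": (torch.rand(64) > 0.7).float(),
}

reg_fn = HATMaskSparsityReg(factor=0.75, mode="original")

# `forward` returns both the scalar loss term and the per-layer sparsity statistics.
reg, network_sparsity = reg_fn(mask, previous_cumulative_mask)

# `reg` would typically be added to the current task's loss before backpropagation.
print(reg, network_sparsity)
```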
`class HATMaskSparsityReg(nn.Module)`
Mask sparsity regulariser of HAT (Hard Attention to the Task).
$$
R\left(\textsf{M}^t,\textsf{M}^{<t}\right)=\text{factor} * \frac{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}\left(1-m_{l,i}^{<t}\right)}
$$
It promotes low capacity usage, as reflected by the occupation of masks in the parameter space.
See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
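Besides the scalar loss above, the implementation also reports a per-layer statistic, returned as `network_sparsity` (in 'original' mode; the symbol $s_l$ is notation introduced here, not from the paper):

$$
s_l = \frac{\sum_{i=1}^{N_l} m_{l,i}^t \left(1-m_{l,i}^{<t}\right)}{\sum_{i=1}^{N_l} \left(1-m_{l,i}^{<t}\right)}
$$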
`def __init__(self, factor: float, mode: str = "original") -> None`
Initialise the regulariser.
**Args:**
- **factor** (`float`): the regularisation factor.
- **mode** (`str`): the mode of mask sparsity regularisation, should be one of the following (see the instantiation sketch after this list):
    1. 'original' (default): the original mask sparsity regularisation in the HAT paper.
    2. 'cross': the cross version of the mask sparsity regularisation.
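A small instantiation sketch of both modes (the factor value 0.75 is an illustrative assumption, not a recommended setting):

```python
from clarena.cl_algorithms.regularisers.hat_mask_sparsity import HATMaskSparsityReg

# Original HAT mask sparsity regularisation
reg_original = HATMaskSparsityReg(factor=0.75, mode="original")

# Cross version of the mask sparsity regularisation
reg_cross = HATMaskSparsityReg(factor=0.75, mode="cross")
```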
`def forward(self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]`
Calculate the mask sparsity regularisation loss.
**Args:**
- **mask** (`dict[str, Tensor]`): the mask for the current task.
- **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks.

**Returns:**
- **reg** (`Tensor`): the mask sparsity regularisation loss (see the training-step sketch below for where it is used).
- **network_sparsity** (`dict[str, Tensor]`): the per-layer sparsity statistics.
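A sketch of where the forward call sits in a training step (the names `logits`, `targets`, and the loss composition are assumptions for illustration; the actual HAT algorithm in `clarena` may wire this differently):

```python
import torch
from torch.nn import functional as F


def hat_step_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: dict[str, torch.Tensor],
    previous_cumulative_mask: dict[str, torch.Tensor],
    reg_fn,  # an instance of HATMaskSparsityReg
) -> torch.Tensor:
    # The regulariser returns the loss term and per-layer statistics (e.g. for logging).
    reg, network_sparsity = reg_fn(mask, previous_cumulative_mask)
    return F.cross_entropy(logits, targets) + reg
```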
`def original_reg(self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]`
Calculate the original mask sparsity regularisation loss in the HAT paper.

See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
**Args:**
- **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
- **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.

**Returns:**
- **reg** (`Tensor`): the original mask sparsity regularisation loss (a worked toy example follows below).
- **network_sparsity** (`dict[str, Tensor]`): the per-layer sparsity statistics.
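A toy worked example (numbers invented for illustration): if, summed over all layers, 100 units are still free of previous-task masks and the current task's mask occupies 20 of them, then

$$
\text{reg} = \text{factor} * \frac{20}{100} = 0.2 \, \text{factor},
$$

so the loss is smaller the fewer free units the new task claims.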
`def cross_reg(self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]`
Calculate the cross mask sparsity regularisation loss. This is an attempted improvement on the original regularisation that considers not only the sparsity in the available units but also the density in the units occupied by previous tasks.
**Args:**
- **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
- **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.

**Returns:**
- **reg** (`Tensor`): the cross mask sparsity regularisation loss (written out explicitly below).
- **network_sparsity** (`dict[str, Tensor]`): the per-layer sparsity statistics.
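Read off the implementation above, the cross version averages the original term with a reuse term (this formula is a restatement of the code, not from the HAT paper):

$$
R_{\text{cross}} = \text{factor} * \frac{1}{2}\left(\frac{\sum_{l,i} m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l,i}\left(1-m_{l,i}^{<t}\right)} + 1 - \frac{\sum_{l,i} m_{l,i}^t \, m_{l,i}^{<t}}{\sum_{l,i} m_{l,i}^{<t}}\right)
$$

The second term is small when the new task's mask overlaps heavily with the previously occupied units, so minimising it encourages reuse of those units.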