clarena.cl_algorithms.regularizers.hat_mask_sparsity
The submodule in regularizers for HAT (Hard Attention to the Task) mask sparsity regularization.
1r"""The submodule in `regularizers` for [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) mask sparsity regularization.""" 2 3__all__ = ["HATMaskSparsityReg"] 4 5import logging 6 7from torch import Tensor, nn 8 9# always get logger for built-in logging in each module 10pylogger = logging.getLogger(__name__) 11 12 13class HATMaskSparsityReg(nn.Module): 14 r"""Mask sparsity regularizer of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a). 15 16 $$ 17 R\left(\textsf{M}^t,\textsf{M}^{<t}\right)=\text{factor} * \frac{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}\left(1-m_{l,i}^{<t}\right)} 18 $$ 19 20 It promotes the low capacity usage that is reflected by occupation of masks in the parameter space. 21 22 See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 23 """ 24 25 def __init__( 26 self, 27 factor: float, 28 mode: str = "original", 29 ) -> None: 30 r""" 31 **Args:** 32 - **factor** (`float`): the regularization factor. 33 - **mode** (`str`): the mode of mask sparsity regularization; one of: 34 1. 'original' (default): the original mask sparsity regularization in HAT paper. 35 2. 'cross': the cross version mask sparsity regularization. 36 """ 37 super().__init__() 38 39 self.factor = factor 40 """The regularization factor for mask sparsity.""" 41 self.mode = mode 42 """The mode of mask sparsity regularization.""" 43 44 def forward( 45 self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor] 46 ) -> Tensor: 47 r"""Calculate the mask sparsity regularization loss. 48 49 **Args:** 50 - **mask** (`dict[str, Tensor]`): the mask for the current task. 51 - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. 52 53 **Returns:** 54 - **reg** (`Tensor`): the mask sparsity regularization value. 
55 """ 56 57 if self.mode == "original": 58 return self.original_reg(mask, previous_cumulative_mask) 59 60 if self.mode == "cross": 61 return self.cross_reg(mask, previous_cumulative_mask) 62 63 def original_reg( 64 self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor] 65 ) -> tuple[Tensor, dict[str, Tensor]]: 66 r"""Calculate the original mask sparsity regularization loss in HAT paper. 67 68 See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 69 70 **Args:** 71 - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper. 72 - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper. 73 74 **Returns:** 75 - **reg** (`Tensor`): the original mask sparsity regularization loss. 76 - **network_sparsity** (`dict[str, Tensor]`): the network sparsity for each layer. Keys are layer names and values are the network sparsity value. 
77 """ 78 79 count_available = 0 # number of units available for the new task 80 count_new_task_occupation_in_available = ( 81 0 # number of units occupied by the new task in the available units 82 ) 83 84 network_sparsity = {} 85 86 for layer_name in mask.keys(): 87 # statistics through all layers 88 available = ( 89 1 - previous_cumulative_mask[layer_name] 90 ).sum() # count the number of units available for the new task 91 92 new_task_occupation_in_available = ( 93 mask[layer_name] * (1 - previous_cumulative_mask[layer_name]) 94 ).sum() 95 # count the number of units occupied by the new task in the available units 96 97 # add to statistics 98 count_available += available 99 count_new_task_occupation_in_available += new_task_occupation_in_available 100 network_sparsity[layer_name] = ( 101 (new_task_occupation_in_available / available) if available > 10 else 0 102 ) 103 104 # the mask sparsity regularization minimises the ratio of the number of units occupied by the new task to the number of units available for the new task. The regularizizer is to let HAT allocates more units from previous tasks to the new task rather than using available units. 105 reg = ( 106 count_new_task_occupation_in_available / count_available 107 if count_available 108 > 10 # to avoid division by a very small number, which makes the regularization less meaningful 109 else 0 110 ) 111 112 return self.factor * reg, network_sparsity 113 114 def cross_reg( 115 self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor] 116 ) -> tuple[Tensor, dict[str, Tensor]]: 117 r"""Calculate the cross mask sparsity regularization loss. This is an attempting improvement by me to the original regularization, which not only considers the sparsity in available units but also the density in the units occupied by previous tasks. 118 119 **Args:** 120 - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper. 
121 - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper. 122 123 **Returns:** 124 - **reg** (`Tensor`): the cross mask sparsity regularization loss. 125 - **network_sparsity** (`dict[str, Tensor]`): the network sparsity for each layer. Keys are layer names and values are the network sparsity value. 126 """ 127 128 count_previous = 0 # number of units occupied by the previous tasks 129 count_new_task_occupation_in_previous = ( 130 0 # number of units occupied by the new task in the previous tasks 131 ) 132 133 network_sparsity_2 = {} 134 135 for layer_name in mask.keys(): 136 # statistics through all layers 137 previous = previous_cumulative_mask[ 138 layer_name 139 ].sum() # count the number of units occupied by the previous tasks 140 new_task_occupation_in_previous = ( 141 mask[layer_name] * previous_cumulative_mask[layer_name].sum() 142 ) # count the number of units occupied by the new task in the previous tasks 143 144 # add to statistics 145 count_previous += previous 146 count_new_task_occupation_in_previous += new_task_occupation_in_previous 147 network_sparsity_2[layer_name] = ( 148 (new_task_occupation_in_previous / previous) if previous > 10 else 0 149 ) 150 151 # the mask sparsity regularization maximises the ratio of the number of units occupied by the new task to the number of units occupied by the previous tasks. The regularizizer is to let HAT allocates more units from previous tasks to the new task rather than using available units. 
152 reg2 = ( 153 1 - count_new_task_occupation_in_previous / count_previous 154 if count_previous 155 > 10 # to avoid division by a very small number, which makes the regularization less meaningful 156 else 0 157 ) 158 159 reg1, network_sparsity_1 = self.original_reg(mask, previous_cumulative_mask) 160 161 reg = ( 162 reg1 + reg2 163 ) / 2 # our cross regularization is the average of the original and the regularization proposed above 164 165 network_sparsity = {} 166 for layer_name in mask.keys(): 167 # merge the two network sparsity statistics 168 network_sparsity[layer_name] = ( 169 network_sparsity_1[layer_name] + network_sparsity_2[layer_name] 170 ) / 2 171 172 return self.factor * reg, network_sparsity
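To make the original regularization concrete, here is a minimal pure-Python sketch of the per-layer statistics for a single toy layer. Plain floats stand in for `torch` tensors, the layer is made larger than the `> 10` availability threshold so the ratio is actually computed, and all numbers are illustrative, not from the paper:

```python
# Toy layer with 20 units. prev is 1.0 where any previous task occupies
# the unit; mask holds the current task's attention values in [0, 1].
prev = [1.0] * 4 + [0.0] * 16                            # 4 units already taken
mask = [1.0] * 2 + [0.0] * 2 + [1.0] * 6 + [0.0] * 10    # reuses 2 old units, claims 6 new

available = sum(1.0 - p for p in prev)                   # units free for the new task
new_in_available = sum(m * (1.0 - p) for m, p in zip(mask, prev))

# reg mirrors original_reg(): occupation of free units, guarded by the threshold
reg = new_in_available / available if available > 10 else 0.0
print(available, new_in_available, reg)  # 16.0 6.0 0.375
```

The 6 newly claimed units out of 16 free ones give a ratio of 0.375; minimising this (times the factor) pushes the new task toward reusing the 4 previously occupied units instead.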
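The cross variant adds a second, reuse-rewarding term computed over the units held by previous tasks and averages it with the original term. A toy computation of that second term, under the same assumptions as above (plain floats, a single layer above the `> 10` threshold, illustrative numbers):

```python
# Toy layer with 30 units; 12 are held by previous tasks.
prev = [1.0] * 12 + [0.0] * 18
mask = [1.0] * 3 + [0.0] * 9 + [1.0] * 6 + [0.0] * 12    # reuses 3 old units, claims 6 new

previous = sum(prev)                                      # units held by previous tasks
new_in_previous = sum(m * p for m, p in zip(mask, prev))  # overlap with previous tasks

# reg2 mirrors cross_reg(): low when the new task reuses many previous units
reg2 = 1.0 - new_in_previous / previous if previous > 10 else 0.0
print(previous, new_in_previous, reg2)  # 12.0 3.0 0.75
```

Reusing only 3 of the 12 previously occupied units leaves `reg2` at 0.75; full reuse would drive it to 0. The final cross loss averages this with the original term, so the new task is simultaneously penalised for claiming free units and rewarded for sharing occupied ones.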