clarena.cl_algorithms.regularizers.hat_mask_sparsity

The submodule in regularizers for HAT (Hard Attention to the Task) mask sparsity regularization.

  1r"""The submodule in `regularizers` for [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) mask sparsity regularization."""
  2
  3__all__ = ["HATMaskSparsityReg"]
  4
  5import logging
  6
  7from torch import Tensor, nn
  8
  9# always get logger for built-in logging in each module
 10pylogger = logging.getLogger(__name__)
 11
 12
 13class HATMaskSparsityReg(nn.Module):
 14    r"""Mask sparsity regularizer of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a).
 15
 16    $$
 17    R\left(\textsf{M}^t,\textsf{M}^{<t}\right)=\text{factor} * \frac{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}\left(1-m_{l,i}^{<t}\right)}
 18    $$
 19
 20    It promotes the low capacity usage that is reflected by occupation of masks in the parameter space.
 21
 22    See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 23    """
 24
 25    def __init__(
 26        self,
 27        factor: float,
 28        mode: str = "original",
 29    ) -> None:
 30        r"""
 31        **Args:**
 32        - **factor** (`float`): the regularization factor.
 33        - **mode** (`str`): the mode of mask sparsity regularization; one of:
 34            1. 'original' (default): the original mask sparsity regularization in HAT paper.
 35            2. 'cross': the cross version mask sparsity regularization.
 36        """
 37        super().__init__()
 38
 39        self.factor = factor
 40        """The regularization factor for mask sparsity."""
 41        self.mode = mode
 42        """The mode of mask sparsity regularization."""
 43
 44    def forward(
 45        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
 46    ) -> Tensor:
 47        r"""Calculate the mask sparsity regularization loss.
 48
 49        **Args:**
 50        - **mask** (`dict[str, Tensor]`): the mask for the current task.
 51        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks.
 52
 53        **Returns:**
 54        - **reg** (`Tensor`): the mask sparsity regularization value.
 55        """
 56
 57        if self.mode == "original":
 58            return self.original_reg(mask, previous_cumulative_mask)
 59
 60        if self.mode == "cross":
 61            return self.cross_reg(mask, previous_cumulative_mask)
 62
 63    def original_reg(
 64        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
 65    ) -> tuple[Tensor, dict[str, Tensor]]:
 66        r"""Calculate the original mask sparsity regularization loss in HAT paper.
 67
 68        See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 69
 70        **Args:**
 71        - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
 72        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.
 73
 74        **Returns:**
 75        - **reg** (`Tensor`): the original mask sparsity regularization loss.
 76        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity for each layer. Keys are layer names and values are the network sparsity value.
 77        """
 78
 79        count_available = 0  # number of units available for the new task
 80        count_new_task_occupation_in_available = (
 81            0  # number of units occupied by the new task in the available units
 82        )
 83
 84        network_sparsity = {}
 85
 86        for layer_name in mask.keys():
 87            # statistics through all layers
 88            available = (
 89                1 - previous_cumulative_mask[layer_name]
 90            ).sum()  # count the number of units available for the new task
 91
 92            new_task_occupation_in_available = (
 93                mask[layer_name] * (1 - previous_cumulative_mask[layer_name])
 94            ).sum()
 95            # count the number of units occupied by the new task in the available units
 96
 97            # add to statistics
 98            count_available += available
 99            count_new_task_occupation_in_available += new_task_occupation_in_available
100            network_sparsity[layer_name] = (
101                (new_task_occupation_in_available / available) if available > 10 else 0
102            )
103
104        # the mask sparsity regularization minimises the ratio of the number of units occupied by the new task to the number of units available for the new task. The regularizizer is to let HAT allocates more units from previous tasks to the new task rather than using available units.
105        reg = (
106            count_new_task_occupation_in_available / count_available
107            if count_available
108            > 10  # to avoid division by a very small number, which makes the regularization less meaningful
109            else 0
110        )
111
112        return self.factor * reg, network_sparsity
113
114    def cross_reg(
115        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
116    ) -> tuple[Tensor, dict[str, Tensor]]:
117        r"""Calculate the cross mask sparsity regularization loss. This is an attempting improvement by me to the original regularization, which not only considers the sparsity in available units but also the density in the units occupied by previous tasks.
118
119        **Args:**
120        - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
121        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.
122
123        **Returns:**
124        - **reg** (`Tensor`): the cross mask sparsity regularization loss.
125        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity for each layer. Keys are layer names and values are the network sparsity value.
126        """
127
128        count_previous = 0  # number of units occupied by the previous tasks
129        count_new_task_occupation_in_previous = (
130            0  # number of units occupied by the new task in the previous tasks
131        )
132
133        network_sparsity_2 = {}
134
135        for layer_name in mask.keys():
136            # statistics through all layers
137            previous = previous_cumulative_mask[
138                layer_name
139            ].sum()  # count the number of units occupied by the previous tasks
140            new_task_occupation_in_previous = (
141                mask[layer_name] * previous_cumulative_mask[layer_name].sum()
142            )  # count the number of units occupied by the new task in the previous tasks
143
144            # add to statistics
145            count_previous += previous
146            count_new_task_occupation_in_previous += new_task_occupation_in_previous
147            network_sparsity_2[layer_name] = (
148                (new_task_occupation_in_previous / previous) if previous > 10 else 0
149            )
150
151        # the mask sparsity regularization maximises the ratio of the number of units occupied by the new task to the number of units occupied by the previous tasks. The regularizizer is to let HAT allocates more units from previous tasks to the new task rather than using available units.
152        reg2 = (
153            1 - count_new_task_occupation_in_previous / count_previous
154            if count_previous
155            > 10  # to avoid division by a very small number, which makes the regularization less meaningful
156            else 0
157        )
158
159        reg1, network_sparsity_1 = self.original_reg(mask, previous_cumulative_mask)
160
161        reg = (
162            reg1 + reg2
163        ) / 2  # our cross regularization is the average of the original and the regularization proposed above
164
165        network_sparsity = {}
166        for layer_name in mask.keys():
167            # merge the two network sparsity statistics
168            network_sparsity[layer_name] = (
169                network_sparsity_1[layer_name] + network_sparsity_2[layer_name]
170            ) / 2
171
172        return self.factor * reg, network_sparsity
class HATMaskSparsityReg(torch.nn.modules.module.Module):
 14class HATMaskSparsityReg(nn.Module):
 15    r"""Mask sparsity regularizer of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a).
 16
 17    $$
 18    R\left(\textsf{M}^t,\textsf{M}^{<t}\right)=\text{factor} * \frac{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}\left(1-m_{l,i}^{<t}\right)}
 19    $$
 20
 21    It promotes the low capacity usage that is reflected by occupation of masks in the parameter space.
 22
 23    See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 24    """
 25
 26    def __init__(
 27        self,
 28        factor: float,
 29        mode: str = "original",
 30    ) -> None:
 31        r"""
 32        **Args:**
 33        - **factor** (`float`): the regularization factor.
 34        - **mode** (`str`): the mode of mask sparsity regularization; one of:
 35            1. 'original' (default): the original mask sparsity regularization in HAT paper.
 36            2. 'cross': the cross version mask sparsity regularization.
 37        """
 38        super().__init__()
 39
 40        self.factor = factor
 41        """The regularization factor for mask sparsity."""
 42        self.mode = mode
 43        """The mode of mask sparsity regularization."""
 44
 45    def forward(
 46        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
 47    ) -> Tensor:
 48        r"""Calculate the mask sparsity regularization loss.
 49
 50        **Args:**
 51        - **mask** (`dict[str, Tensor]`): the mask for the current task.
 52        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks.
 53
 54        **Returns:**
 55        - **reg** (`Tensor`): the mask sparsity regularization value.
 56        """
 57
 58        if self.mode == "original":
 59            return self.original_reg(mask, previous_cumulative_mask)
 60
 61        if self.mode == "cross":
 62            return self.cross_reg(mask, previous_cumulative_mask)
 63
 64    def original_reg(
 65        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
 66    ) -> tuple[Tensor, dict[str, Tensor]]:
 67        r"""Calculate the original mask sparsity regularization loss in HAT paper.
 68
 69        See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 70
 71        **Args:**
 72        - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
 73        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.
 74
 75        **Returns:**
 76        - **reg** (`Tensor`): the original mask sparsity regularization loss.
 77        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity for each layer. Keys are layer names and values are the network sparsity value.
 78        """
 79
 80        count_available = 0  # number of units available for the new task
 81        count_new_task_occupation_in_available = (
 82            0  # number of units occupied by the new task in the available units
 83        )
 84
 85        network_sparsity = {}
 86
 87        for layer_name in mask.keys():
 88            # statistics through all layers
 89            available = (
 90                1 - previous_cumulative_mask[layer_name]
 91            ).sum()  # count the number of units available for the new task
 92
 93            new_task_occupation_in_available = (
 94                mask[layer_name] * (1 - previous_cumulative_mask[layer_name])
 95            ).sum()
 96            # count the number of units occupied by the new task in the available units
 97
 98            # add to statistics
 99            count_available += available
100            count_new_task_occupation_in_available += new_task_occupation_in_available
101            network_sparsity[layer_name] = (
102                (new_task_occupation_in_available / available) if available > 10 else 0
103            )
104
105        # the mask sparsity regularization minimises the ratio of the number of units occupied by the new task to the number of units available for the new task. The regularizizer is to let HAT allocates more units from previous tasks to the new task rather than using available units.
106        reg = (
107            count_new_task_occupation_in_available / count_available
108            if count_available
109            > 10  # to avoid division by a very small number, which makes the regularization less meaningful
110            else 0
111        )
112
113        return self.factor * reg, network_sparsity
114
115    def cross_reg(
116        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
117    ) -> tuple[Tensor, dict[str, Tensor]]:
118        r"""Calculate the cross mask sparsity regularization loss. This is an attempting improvement by me to the original regularization, which not only considers the sparsity in available units but also the density in the units occupied by previous tasks.
119
120        **Args:**
121        - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
122        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.
123
124        **Returns:**
125        - **reg** (`Tensor`): the cross mask sparsity regularization loss.
126        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity for each layer. Keys are layer names and values are the network sparsity value.
127        """
128
129        count_previous = 0  # number of units occupied by the previous tasks
130        count_new_task_occupation_in_previous = (
131            0  # number of units occupied by the new task in the previous tasks
132        )
133
134        network_sparsity_2 = {}
135
136        for layer_name in mask.keys():
137            # statistics through all layers
138            previous = previous_cumulative_mask[
139                layer_name
140            ].sum()  # count the number of units occupied by the previous tasks
141            new_task_occupation_in_previous = (
142                mask[layer_name] * previous_cumulative_mask[layer_name].sum()
143            )  # count the number of units occupied by the new task in the previous tasks
144
145            # add to statistics
146            count_previous += previous
147            count_new_task_occupation_in_previous += new_task_occupation_in_previous
148            network_sparsity_2[layer_name] = (
149                (new_task_occupation_in_previous / previous) if previous > 10 else 0
150            )
151
152        # the mask sparsity regularization maximises the ratio of the number of units occupied by the new task to the number of units occupied by the previous tasks. The regularizizer is to let HAT allocates more units from previous tasks to the new task rather than using available units.
153        reg2 = (
154            1 - count_new_task_occupation_in_previous / count_previous
155            if count_previous
156            > 10  # to avoid division by a very small number, which makes the regularization less meaningful
157            else 0
158        )
159
160        reg1, network_sparsity_1 = self.original_reg(mask, previous_cumulative_mask)
161
162        reg = (
163            reg1 + reg2
164        ) / 2  # our cross regularization is the average of the original and the regularization proposed above
165
166        network_sparsity = {}
167        for layer_name in mask.keys():
168            # merge the two network sparsity statistics
169            network_sparsity[layer_name] = (
170                network_sparsity_1[layer_name] + network_sparsity_2[layer_name]
171            ) / 2
172
173        return self.factor * reg, network_sparsity

Mask sparsity regularizer of HAT (Hard Attention to the Task).

$$ R\left(\textsf{M}^t,\textsf{M}^{<t}\right)=\text{factor} * \frac{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}\left(1-m_{l,i}^{<t}\right)} $$

It promotes the low capacity usage that is reflected by occupation of masks in the parameter space.

See chapter 2.6 "Promoting Low Capacity Usage" in the HAT paper.

HATMaskSparsityReg(factor: float, mode: str = 'original')
26    def __init__(
27        self,
28        factor: float,
29        mode: str = "original",
30    ) -> None:
31        r"""
32        **Args:**
33        - **factor** (`float`): the regularization factor.
34        - **mode** (`str`): the mode of mask sparsity regularization; one of:
35            1. 'original' (default): the original mask sparsity regularization in HAT paper.
36            2. 'cross': the cross version mask sparsity regularization.
37        """
38        super().__init__()
39
40        self.factor = factor
41        """The regularization factor for mask sparsity."""
42        self.mode = mode
43        """The mode of mask sparsity regularization."""

Args:

  • factor (float): the regularization factor.
  • mode (str): the mode of mask sparsity regularization; one of:
    1. 'original' (default): the original mask sparsity regularization in HAT paper.
    2. 'cross': the cross version mask sparsity regularization.
factor

The regularization factor for mask sparsity.

mode

The mode of mask sparsity regularization.

def forward( self, mask: dict[str, torch.Tensor], previous_cumulative_mask: dict[str, torch.Tensor]) -> torch.Tensor:
45    def forward(
46        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
47    ) -> Tensor:
48        r"""Calculate the mask sparsity regularization loss.
49
50        **Args:**
51        - **mask** (`dict[str, Tensor]`): the mask for the current task.
52        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks.
53
54        **Returns:**
55        - **reg** (`Tensor`): the mask sparsity regularization value.
56        """
57
58        if self.mode == "original":
59            return self.original_reg(mask, previous_cumulative_mask)
60
61        if self.mode == "cross":
62            return self.cross_reg(mask, previous_cumulative_mask)

Calculate the mask sparsity regularization loss.

Args:

  • mask (dict[str, Tensor]): the mask for the current task.
  • previous_cumulative_mask (dict[str, Tensor]): the cumulative mask for the previous tasks.

Returns:

  • reg (Tensor): the mask sparsity regularization value.
def original_reg( self, mask: dict[str, torch.Tensor], previous_cumulative_mask: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
 64    def original_reg(
 65        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
 66    ) -> tuple[Tensor, dict[str, Tensor]]:
 67        r"""Calculate the original mask sparsity regularization loss in HAT paper.
 68
 69        See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 70
 71        **Args:**
 72        - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
 73        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.
 74
 75        **Returns:**
 76        - **reg** (`Tensor`): the original mask sparsity regularization loss.
 77        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity for each layer. Keys are layer names and values are the network sparsity value.
 78        """
 79
 80        count_available = 0  # number of units available for the new task
 81        count_new_task_occupation_in_available = (
 82            0  # number of units occupied by the new task in the available units
 83        )
 84
 85        network_sparsity = {}
 86
 87        for layer_name in mask.keys():
 88            # statistics through all layers
 89            available = (
 90                1 - previous_cumulative_mask[layer_name]
 91            ).sum()  # count the number of units available for the new task
 92
 93            new_task_occupation_in_available = (
 94                mask[layer_name] * (1 - previous_cumulative_mask[layer_name])
 95            ).sum()
 96            # count the number of units occupied by the new task in the available units
 97
 98            # add to statistics
 99            count_available += available
100            count_new_task_occupation_in_available += new_task_occupation_in_available
101            network_sparsity[layer_name] = (
102                (new_task_occupation_in_available / available) if available > 10 else 0
103            )
104
105        # the mask sparsity regularization minimises the ratio of the number of units occupied by the new task to the number of units available for the new task. The regularizizer is to let HAT allocates more units from previous tasks to the new task rather than using available units.
106        reg = (
107            count_new_task_occupation_in_available / count_available
108            if count_available
109            > 10  # to avoid division by a very small number, which makes the regularization less meaningful
110            else 0
111        )
112
113        return self.factor * reg, network_sparsity

Calculate the original mask sparsity regularization loss in HAT paper.

See chapter 2.6 "Promoting Low Capacity Usage" in the HAT paper.

Args:

  • mask (dict[str, Tensor]): the mask for the current task. The $\mathrm{A}^t$ in the paper.
  • previous_cumulative_mask (dict[str, Tensor]): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.

Returns:

  • reg (Tensor): the original mask sparsity regularization loss.
  • network_sparsity (dict[str, Tensor]): the network sparsity for each layer. Keys are layer names and values are the network sparsity value.
def cross_reg( self, mask: dict[str, torch.Tensor], previous_cumulative_mask: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
115    def cross_reg(
116        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
117    ) -> tuple[Tensor, dict[str, Tensor]]:
118        r"""Calculate the cross mask sparsity regularization loss. This is an attempting improvement by me to the original regularization, which not only considers the sparsity in available units but also the density in the units occupied by previous tasks.
119
120        **Args:**
121        - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
122        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.
123
124        **Returns:**
125        - **reg** (`Tensor`): the cross mask sparsity regularization loss.
126        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity for each layer. Keys are layer names and values are the network sparsity value.
127        """
128
129        count_previous = 0  # number of units occupied by the previous tasks
130        count_new_task_occupation_in_previous = (
131            0  # number of units occupied by the new task in the previous tasks
132        )
133
134        network_sparsity_2 = {}
135
136        for layer_name in mask.keys():
137            # statistics through all layers
138            previous = previous_cumulative_mask[
139                layer_name
140            ].sum()  # count the number of units occupied by the previous tasks
141            new_task_occupation_in_previous = (
142                mask[layer_name] * previous_cumulative_mask[layer_name].sum()
143            )  # count the number of units occupied by the new task in the previous tasks
144
145            # add to statistics
146            count_previous += previous
147            count_new_task_occupation_in_previous += new_task_occupation_in_previous
148            network_sparsity_2[layer_name] = (
149                (new_task_occupation_in_previous / previous) if previous > 10 else 0
150            )
151
152        # the mask sparsity regularization maximises the ratio of the number of units occupied by the new task to the number of units occupied by the previous tasks. The regularizizer is to let HAT allocates more units from previous tasks to the new task rather than using available units.
153        reg2 = (
154            1 - count_new_task_occupation_in_previous / count_previous
155            if count_previous
156            > 10  # to avoid division by a very small number, which makes the regularization less meaningful
157            else 0
158        )
159
160        reg1, network_sparsity_1 = self.original_reg(mask, previous_cumulative_mask)
161
162        reg = (
163            reg1 + reg2
164        ) / 2  # our cross regularization is the average of the original and the regularization proposed above
165
166        network_sparsity = {}
167        for layer_name in mask.keys():
168            # merge the two network sparsity statistics
169            network_sparsity[layer_name] = (
170                network_sparsity_1[layer_name] + network_sparsity_2[layer_name]
171            ) / 2
172
173        return self.factor * reg, network_sparsity

Calculate the cross mask sparsity regularization loss. This is an attempted improvement on the original regularization, which considers not only the sparsity in available units but also the density in the units occupied by previous tasks.

Args:

  • mask (dict[str, Tensor]): the mask for the current task. The $\mathrm{A}^t$ in the paper.
  • previous_cumulative_mask (dict[str, Tensor]): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.

Returns:

  • reg (Tensor): the cross mask sparsity regularization loss.
  • network_sparsity (dict[str, Tensor]): the network sparsity for each layer. Keys are layer names and values are the network sparsity value.