clarena.cl_algorithms.regularisers.hat_mask_sparsity

The submodule in regularisers for HAT mask sparsity regularisation.

  1r"""The submodule in `regularisers` for HAT mask sparsity regularisation."""
  2
  3__all__ = ["HATMaskSparsityReg"]
  4
  5from torch import Tensor, nn
  6
  7
  8class HATMaskSparsityReg(nn.Module):
  9    r"""Mask sparsity regulariser of HAT (Hard Attention to the Task).
 10
 11    $$
 12    R\left(\textsf{M}^t,\textsf{M}^{<t}\right)=\text{factor} * \frac{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}\left(1-m_{l,i}^{<t}\right)}
 13    $$
 14
 15    It promotes the low capacity usage that is reflected by occupation of masks in the parameter space.
 16
 17    See chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 18    """
 19
 20    def __init__(
 21        self,
 22        factor: float,
 23        mode: str = "original",
 24    ) -> None:
 25        r"""Initialise the regulariser.
 26
 27        **Args:**
 28        - **factor** (`float`): the regularisation factor.
 29        - **mode** (`str`): the mode of mask sparsity regularisation, should be one of the following:
 30            1. 'original' (default): the original mask sparsity regularisation in HAT paper.
 31            2. 'cross': the cross version mask sparsity regularisation.
 32        """
 33        nn.Module.__init__(self)
 34
 35        self.factor = factor
 36        """Store the regularisation factor for mask sparsity."""
 37        self.mode = mode
 38        """Store the mode of mask sparsity regularisation."""
 39
 40    def forward(
 41        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
 42    ) -> Tensor:
 43        r"""Calculate the mask sparsity regularisation loss.
 44
 45        **Args:**
 46        - **mask** (`dict[str, Tensor]`): the mask for the current task.
 47        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks.
 48
 49        **Returns:**
 50        - **reg** (`Tensor`): the mask sparsity regularisation value.
 51        """
 52
 53        if self.mode == "original":
 54            return self.original_reg(mask, previous_cumulative_mask)
 55        
 56        if self.mode == "cross":
 57            return self.cross_reg(mask, previous_cumulative_mask)
 58
 59    def original_reg(
 60        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
 61    ) -> Tensor:
 62        r"""Calculate the original mask sparsity regularisation loss in HAT paper.
 63
 64        See chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 65
 66        **Args:**
 67        - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
 68        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.
 69
 70        **Returns:**
 71        - **reg** (`Tensor`): the original mask sparsity regularisation loss.
 72        """
 73
 74        count_available = 0  # number of units available for the new task
 75        count_new_task_occupation_in_available = (
 76            0  # number of units occupied by the new task in the available units
 77        )
 78
 79        network_sparsity = {}
 80
 81        for layer_name in mask.keys():
 82            # statistics through all layers
 83            available = (
 84                1 - previous_cumulative_mask[layer_name]
 85            ).sum()  # count the number of units available for the new task
 86
 87            new_task_occupation_in_available = (
 88                mask[layer_name] * (1 - previous_cumulative_mask[layer_name])
 89            ).sum()
 90            # count the number of units occupied by the new task in the available units
 91
 92            # add to statistics
 93            count_available += available
 94            count_new_task_occupation_in_available += new_task_occupation_in_available
 95            network_sparsity[layer_name] = (
 96                (new_task_occupation_in_available / available) if available > 10 else 0
 97            )
 98
 99        # the mask sparsity regularisation minimises the ratio of the number of units occupied by the new task to the number of units available for the new task. The regularisizer is to let HAT allocates more units from previous tasks to the new task rather than using available units.
100        reg = (
101            count_new_task_occupation_in_available / count_available
102            if count_available
103            > 10  # to avoid division by a very small number, which makes the regularisation less meaningful
104            else 0
105        )
106
107        return self.factor * reg, network_sparsity
108
109    def cross_reg(
110        self, mask: dict[str, Tensor], previous_cumulative_mask: dict[str, Tensor]
111    ) -> Tensor:
112        r"""Calculate the cross mask sparsity regularisation loss. This is an attempting improvement by me to the original regularisation, which not only considers the sparsity in available units but also the density in the units occupied by previous tasks.
113
114        **Args:**
115        - **mask** (`dict[str, Tensor]`): the mask for the current task. The $\mathrm{A}^t$ in the paper.
116        - **previous_cumulative_mask** (`dict[str, Tensor]`): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.
117
118        **Returns:**
119        - **reg** (`Tensor`): the cross mask sparsity regularisation loss.
120        """
121
122        count_previous = 0  # number of units occupied by the previous tasks
123        count_new_task_occupation_in_previous = (
124            0  # number of units occupied by the new task in the previous tasks
125        )
126
127        network_sparsity_2 = {}
128
129        for layer_name in mask.keys():
130            # statistics through all layers
131            previous = previous_cumulative_mask[
132                layer_name
133            ].sum()  # count the number of units occupied by the previous tasks
134            new_task_occupation_in_previous = (
135                mask[layer_name] * previous_cumulative_mask[layer_name].sum()
136            )  # count the number of units occupied by the new task in the previous tasks
137
138            # add to statistics
139            count_previous += previous
140            count_new_task_occupation_in_previous += new_task_occupation_in_previous
141            network_sparsity_2[layer_name] = (
142                (new_task_occupation_in_previous / previous) if previous > 10 else 0
143            )
144
145        # the mask sparsity regularisation maximises the ratio of the number of units occupied by the new task to the number of units occupied by the previous tasks. The regularisizer is to let HAT allocates more units from previous tasks to the new task rather than using available units.
146        reg2 = (
147            1 - count_new_task_occupation_in_previous / count_previous
148            if count_previous
149            > 10  # to avoid division by a very small number, which makes the regularisation less meaningful
150            else 0
151        )
152
153        reg1, network_sparsity_1 = self.original_reg(mask, previous_cumulative_mask)
154
155        reg = (
156            reg1 + reg2
157        ) / 2  # our cross regularisation is the average of the original and the regularisation proposed above
158
159        network_sparsity = {}
160        for layer_name in mask.keys():
161            # merge the two network sparsity statistics
162            network_sparsity[layer_name] = (
163                network_sparsity_1[layer_name] + network_sparsity_2[layer_name]
164            ) / 2
165
166        return self.factor * reg, network_sparsity
class HATMaskSparsityReg(torch.nn.modules.module.Module):

Mask sparsity regulariser of HAT (Hard Attention to the Task).

$$ R\left(\textsf{M}^t,\textsf{M}^{<t}\right)=\text{factor} * \frac{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}m_{l,i}^t\left(1-m_{l,i}^{<t}\right)}{\sum_{l=1}^{L-1}\sum_{i=1}^{N_l}\left(1-m_{l,i}^{<t}\right)} $$

It promotes low capacity usage, which is reflected by the occupation of masks in the parameter space.

See section 2.6 "Promoting Low Capacity Usage" in the HAT paper.
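
A minimal usage sketch (the layer names, sizes, and mask values below are illustrative; masks are binary tensors keyed by layer name, as in the source above):

import torch
from clarena.cl_algorithms.regularisers.hat_mask_sparsity import HATMaskSparsityReg

reg_fn = HATMaskSparsityReg(factor=0.75, mode="original")

# Hypothetical two-layer backbone: no units occupied by previous tasks yet,
# and the current task's mask switches every unit on.
mask = {"fc1": torch.ones(16), "fc2": torch.ones(16)}
previous_cumulative_mask = {"fc1": torch.zeros(16), "fc2": torch.zeros(16)}

reg, network_sparsity = reg_fn(mask, previous_cumulative_mask)
# reg is a scalar tensor already scaled by `factor` (here 0.75 * 1.0);
# network_sparsity maps each layer name to its occupation ratio of available units.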

HATMaskSparsityReg(factor: float, mode: str = 'original')

Initialise the regulariser.

Args:

  • factor (float): the regularisation factor.
  • mode (str): the mode of mask sparsity regularisation, should be one of the following:
    1. 'original' (default): the original mask sparsity regularisation in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularisation.
factor

Store the regularisation factor for mask sparsity.

mode

Store the mode of mask sparsity regularisation.

def forward( self, mask: dict[str, torch.Tensor], previous_cumulative_mask: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:

Calculate the mask sparsity regularisation loss.

Args:

  • mask (dict[str, Tensor]): the mask for the current task.
  • previous_cumulative_mask (dict[str, Tensor]): the cumulative mask for the previous tasks.

Returns:

  • reg (Tensor): the mask sparsity regularisation value.
  • network_sparsity (dict[str, Tensor]): the per-layer sparsity statistics.
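
A short sketch comparing the two modes on a toy single-layer mask (the layer name and sizes are made up; the numbers in the comments follow the implementation above):

import torch
from clarena.cl_algorithms.regularisers.hat_mask_sparsity import HATMaskSparsityReg

# 12 units already occupied by previous tasks; the new task reuses all of them
# and additionally takes 4 of the 20 available units.
previous_cumulative_mask = {"fc": torch.cat([torch.ones(12), torch.zeros(20)])}
mask = {"fc": torch.cat([torch.ones(16), torch.zeros(16)])}

for mode in ("original", "cross"):
    reg, network_sparsity = HATMaskSparsityReg(factor=1.0, mode=mode)(
        mask, previous_cumulative_mask
    )
    print(mode, float(reg))  # original: 4 / 20 = 0.2; cross: (0.2 + 0.0) / 2 = 0.1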
def original_reg( self, mask: dict[str, torch.Tensor], previous_cumulative_mask: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:

Calculate the original mask sparsity regularisation loss in the HAT paper.

See section 2.6 "Promoting Low Capacity Usage" in the HAT paper.

Args:

  • mask (dict[str, Tensor]): the mask for the current task. The $\mathrm{A}^t$ in the paper.
  • previous_cumulative_mask (dict[str, Tensor]): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.

Returns:

  • reg (Tensor): the original mask sparsity regularisation loss.
  • network_sparsity (dict[str, Tensor]): the per-layer ratio of units newly occupied by the current task to units available for it.
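
A worked numeric example of the original term (hypothetical layer names and sizes; the values in the comments can be checked against the formula above):

import torch
from clarena.cl_algorithms.regularisers.hat_mask_sparsity import HATMaskSparsityReg

previous_cumulative_mask = {
    "fc1": torch.cat([torch.ones(4), torch.zeros(12)]),  # 12 units still available
    "fc2": torch.zeros(16),                              # 16 units still available
}
mask = {
    "fc1": torch.cat([torch.ones(8), torch.zeros(8)]),   # 4 of these 8 units are newly taken from the available ones
    "fc2": torch.cat([torch.ones(8), torch.zeros(8)]),   # all 8 units are newly taken from the available ones
}

reg, network_sparsity = HATMaskSparsityReg(factor=0.75, mode="original")(
    mask, previous_cumulative_mask
)
# reg = 0.75 * (4 + 8) / (12 + 16) ≈ 0.321
# network_sparsity = {"fc1": 4 / 12 ≈ 0.333, "fc2": 8 / 16 = 0.5}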
def cross_reg( self, mask: dict[str, torch.Tensor], previous_cumulative_mask: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:

Calculate the cross mask sparsity regularisation loss. This is an attempted improvement on the original regularisation: it considers not only the sparsity in the available units but also the density in the units occupied by previous tasks.

Args:

  • mask (dict[str, Tensor]): the mask for the current task. The $\mathrm{A}^t$ in the paper.
  • previous_cumulative_mask (dict[str, Tensor]): the cumulative mask for the previous tasks. The $\mathrm{A}^{<t}$ in the paper.

Returns:

  • reg (Tensor): the cross mask sparsity regularisation loss.
  • network_sparsity (dict[str, Tensor]): the per-layer sparsity statistics, averaged over the original and cross terms.
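
Reading off the implementation above, the cross mode averages the original occupation ratio over available units with a density term over previously occupied units (layer and unit subscripts are omitted for brevity; this is a sketch of the combined objective, not a formula from the HAT paper):

$$ R_{\text{cross}} = \text{factor} \cdot \frac{1}{2}\left[\frac{\sum m^t\left(1-m^{<t}\right)}{\sum \left(1-m^{<t}\right)} + \left(1 - \frac{\sum m^t\, m^{<t}}{\sum m^{<t}}\right)\right] $$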