clarena.cl_algorithms.adahat

The submodule in cl_algorithms for the AdaHAT (Adaptive Hard Attention to the Task) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for [AdaHAT (Adaptive Hard Attention to the Task) algorithm](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
  3"""
  4
  5__all__ = ["AdaHAT"]
  6
  7import logging
  8
  9import torch
 10from torch import Tensor
 11
 12from clarena.backbones import HATMaskBackbone
 13from clarena.cl_algorithms import HAT
 14from clarena.cl_heads import HeadsCIL, HeadsTIL
 15from clarena.utils import HATNetworkCapacity
 16
 17# always get logger for built-in logging in each module
 18pylogger = logging.getLogger(__name__)
 19
 20
 21class AdaHAT(HAT):
 22    r"""AdaHAT (Adaptive Hard Attention to the Task) algorithm.
 23
 24    [Adaptive HAT (Adaptive Hard Attention to the Task, 2024)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) is an architecture-based continual learning approach that improves [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) by introducing new adaptive soft gradient clipping based on parameter importance and network sparsity.
 25
 26    We implement AdaHAT as a subclass of HAT algorithm, as AdaHAT has the same  `forward()`, `compensate_task_embedding_gradients()`, `training_step()`, `on_train_end()`,`validation_step()`, `test_step()` method as `HAT` class.
 27    """
 28
 29    def __init__(
 30        self,
 31        backbone: HATMaskBackbone,
 32        heads: HeadsTIL | HeadsCIL,
 33        adjustment_mode: str,
 34        adjustment_intensity: float,
 35        s_max: float,
 36        clamp_threshold: float,
 37        mask_sparsity_reg_factor: float,
 38        mask_sparsity_reg_mode: str = "original",
 39        task_embedding_init_mode: str = "N01",
 40        epsilon: float = 0.1,
 41    ) -> None:
 42        r"""Initialise the AdaHAT algorithm with the network.
 43
 44        **Args:**
 45        - **backbone** (`HATMaskBackbone`): must be a backbone network with HAT mask mechanism.
 46        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
 47        - **adjustment_mode** (`str`): the strategy of adjustment i.e. the mode of gradient clipping, should be one of the following:
 48            1. 'adahat': set the gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach. This is the way that AdaHAT does, which allowes the part of network for previous tasks to be updated slightly. See equation (8) and (9) chapter 3.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 49            2. 'adahat_no_sum': set the gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach, but without considering the information of parameter importance i.e. summative mask. This is the way that one of the AdaHAT ablation study does. See chapter 4.3 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 50            3. 'adahat_no_reg': set the gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach, but without considering the information of network sparsity i.e. mask sparsity regularisation value. This is the way that one of the AdaHAT ablation study does. See chapter 4.3 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 51        - **adjustment_intensity** (`float`): hyperparameter, control the overall intensity of gradient adjustment. It's the $\alpha$ in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 52        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 53        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 54        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularisation factor for mask sparsity.
 55        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularisation, should be one of the following:
 56            1. 'original' (default): the original mask sparsity regularisation in HAT paper.
 57            2. 'cross': the cross version mask sparsity regularisation.
 58        - **task_embedding_init_mode** (`str`): the initialisation method for task embeddings, should be one of the following:
 59            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 60            2. 'U-11': uniform distribution $U(-1, 1)$.
 61            3. 'U01': uniform distribution $U(0, 1)$.
 62            4. 'U-10': uniform distribution $U(-1, 0)$.
 63            5. 'last': inherit task embedding from last task.
 64        - **epsilon** (`float`): the value added to network sparsity to avoid division by zero appeared in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 65        """
 66        HAT.__init__(
 67            self,
 68            backbone=backbone,
 69            heads=heads,
 70            adjustment_mode=adjustment_mode,
 71            s_max=s_max,
 72            clamp_threshold=clamp_threshold,
 73            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
 74            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
 75            task_embedding_init_mode=task_embedding_init_mode,
 76            alpha=None,
 77        )
 78
 79        self.adjustment_intensity = adjustment_intensity
 80        r"""Store the adjustment intensity in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)."""
 81        self.epsilon = epsilon
 82        """Store the small value to avoid division by zero appeared in equation (9) in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)."""
 83
 84        self.summative_mask_for_previous_tasks: dict[str, Tensor] = {}
 85        r"""Store the summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ previous tasks $1,\cdots, t-1$, gated from the task embedding. Keys are task IDs and values are the corresponding summative mask. Each cumulative mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units). """
 86
 87        # set manual optimisation
 88        self.automatic_optimization = False
 89
 90        AdaHAT.sanity_check(self)
 91
 92    def sanity_check(self) -> None:
 93        r"""Check the sanity of the arguments.
 94
 95        **Raises:**
 96        - **ValueError**: If the `adjustment_intensity` is not positive.
 97        """
 98        if self.adjustment_intensity <= 0:
 99            raise ValueError(
100                f"The adjustment intensity should be positive, but got {self.adjustment_intensity}."
101            )
102
103    def on_train_start(self) -> None:
104        r"""Additionally initialise the summative mask at the beginning of first task."""
105        HAT.on_train_start(self)
106
107        # initialise the summative mask at the beginning of first task. This should not be called in `__init__()` method as the `self.device` is not available at that time.
108        if self.task_id == 1:
109            for layer_name in self.backbone.weighted_layer_names:
110                layer = self.backbone.get_layer_by_name(
111                    layer_name
112                )  # get the layer by its name
113                num_units = layer.weight.shape[0]
114
115                self.summative_mask_for_previous_tasks[layer_name] = torch.zeros(
116                    num_units
117                ).to(
118                    self.device
119                )  # the summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialised as zeros mask ($t = 1$). See equation (7) in chapter 3.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
120
121    def clip_grad_by_adjustment(
122        self,
123        network_sparsity: dict[str, Tensor] | None = None,
124    ) -> Tensor:
125        r"""Clip the gradients by the adjustment rate.
126
127        Note that as the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only the parameters in between layers with task embedding, but also those before the first layer. We designed it seperately in the codes.
128
129        Network capacity is measured along with this method. Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
130
131        **Args:**
132        - **network_sparsity** (`dict[str, Tensor]` | `None`): The network sparsity i.e. the mask sparsity loss of each layer for the current task. It applies only to AdaHAT modes, as it is used to calculate the adjustment rate for the gradients.
133
134        **Returns:**
135        - **capacity** (`Tensor`): the calculated network capacity.
136        """
137
138        # initialise network capacity metric
139        capacity = HATNetworkCapacity()
140
141        # Calculate the adjustment rate for gradients of the parameters, both weights and biases (if exists)
142        for layer_name in self.backbone.weighted_layer_names:
143
144            layer = self.backbone.get_layer_by_name(
145                layer_name
146            )  # get the layer by its name
147
148            # placeholder for the adjustment rate to avoid the error of using it before assignment
149            adjustment_rate_weight = 1
150            adjustment_rate_bias = 1
151
152            weight_importance, bias_importance = (
153                self.backbone.get_layer_measure_parameter_wise(
154                    unit_wise_measure=self.summative_mask_for_previous_tasks,
155                    layer_name=layer_name,
156                    aggregation="min",
157                )
158            )  # AdaHAT depend on parameter importance instead of parameter mask like HAT
159
160            network_sparsity_layer = network_sparsity[layer_name]
161
162            if self.adjustment_mode == "adahat":
163                r_layer = self.adjustment_intensity / (
164                    self.epsilon + network_sparsity_layer
165                )
166                adjustment_rate_weight = torch.div(
167                    r_layer, (weight_importance + r_layer)
168                )
169                adjustment_rate_bias = torch.div(r_layer, (bias_importance + r_layer))
170
171            elif self.adjustment_mode == "adahat_no_sum":
172
173                r_layer = self.adjustment_intensity / (
174                    self.epsilon + network_sparsity_layer
175                )
176                adjustment_rate_weight = torch.div(r_layer, (self.task_id + r_layer))
177                adjustment_rate_bias = torch.div(r_layer, (self.task_id + r_layer))
178
179            elif self.adjustment_mode == "adahat_no_reg":
180
181                r_layer = self.adjustment_intensity / (self.epsilon + 0.0)
182                adjustment_rate_weight = torch.div(
183                    r_layer, (weight_importance + r_layer)
184                )
185                adjustment_rate_bias = torch.div(r_layer, (bias_importance + r_layer))
186
187            # apply the adjustment rate to the gradients
188            layer.weight.grad.data *= adjustment_rate_weight
189            if layer.bias is not None:
190                layer.bias.grad.data *= adjustment_rate_bias
191
192            # update network capacity metric
193            capacity.update(adjustment_rate_weight, adjustment_rate_bias)
194
195        return capacity.compute()
196
197    def on_train_end(self) -> None:
198        r"""Additionally update summative mask after training the task."""
199
200        HAT.on_train_end(self)
201
202        mask_t = self.masks[
203            f"{self.task_id}"
204        ]  # get stored mask for the current task again
205        self.summative_mask_for_previous_tasks = {
206            layer_name: self.summative_mask_for_previous_tasks[layer_name]
207            + mask_t[layer_name]
208            for layer_name in self.backbone.weighted_layer_names
209        }
class AdaHAT(clarena.cl_algorithms.hat.HAT):

AdaHAT (Adaptive Hard Attention to the Task) algorithm.

Adaptive HAT (Adaptive Hard Attention to the Task, 2024) is an architecture-based continual learning approach that improves HAT (Hard Attention to the Task, 2018) by introducing new adaptive soft gradient clipping based on parameter importance and network sparsity.

We implement AdaHAT as a subclass of the HAT algorithm, as AdaHAT shares the forward(), compensate_task_embedding_gradients(), training_step(), validation_step() and test_step() methods with the HAT class.
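As a compact restatement of the rule implemented in clip_grad_by_adjustment() (see the source above and the method documentation below; the symbols $R_l$ and $m$ are shorthand introduced here), each gradient in layer $l$ is multiplied by the soft adjustment rate

    $r_l = \alpha / (\epsilon + R_l)$,    $\rho = r_l / (m + r_l)$,

where $\alpha$ is adjustment_intensity, $\epsilon$ is epsilon, $R_l$ is the layer's network sparsity (mask sparsity loss), and $m$ is the parameter-wise minimum of the summative mask over previous tasks. Parameters rarely used by previous tasks ($m$ close to 0) keep $\rho$ close to 1, while heavily reused parameters have their gradients clipped towards 0.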

AdaHAT( backbone: clarena.backbones.HATMaskBackbone, heads: clarena.cl_heads.HeadsTIL | clarena.cl_heads.HeadsCIL, adjustment_mode: str, adjustment_intensity: float, s_max: float, clamp_threshold: float, mask_sparsity_reg_factor: float, mask_sparsity_reg_mode: str = 'original', task_embedding_init_mode: str = 'N01', epsilon: float = 0.1)

Initialise the AdaHAT algorithm with the network.

Args:

  • backbone (HATMaskBackbone): must be a backbone network with HAT mask mechanism.
  • heads (HeadsTIL | HeadsCIL): output heads.
  • adjustment_mode (str): the adjustment strategy, i.e. the mode of gradient clipping, should be one of the following:
    1. 'adahat': scale the gradients of parameters linked to masked units by a soft adjustment rate, as in the original AdaHAT approach. This allows the part of the network allocated to previous tasks to be updated slightly. See equations (8) and (9) in chapter 3.1 of the AdaHAT paper.
    2. 'adahat_no_sum': same as 'adahat', but without the parameter importance information, i.e. the summative mask. This corresponds to one of the AdaHAT ablation studies. See chapter 4.3 of the AdaHAT paper.
    3. 'adahat_no_reg': same as 'adahat', but without the network sparsity information, i.e. the mask sparsity regularisation value. This corresponds to one of the AdaHAT ablation studies. See chapter 4.3 of the AdaHAT paper.
  • adjustment_intensity (float): hyperparameter controlling the overall intensity of the gradient adjustment. It is the $\alpha$ in equation (9) of the AdaHAT paper.
  • s_max (float): hyperparameter, the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in HAT paper.
  • clamp_threshold (float): the threshold for task embedding gradient compensation. See chapter 2.5 "Embedding Gradient Compensation" in HAT paper.
  • mask_sparsity_reg_factor (float): hyperparameter, the regularisation factor for mask sparsity.
  • mask_sparsity_reg_mode (str): the mode of mask sparsity regularisation, should be one of the following:
    1. 'original' (default): the original mask sparsity regularisation in HAT paper.
    2. 'cross': the cross version mask sparsity regularisation.
  • task_embedding_init_mode (str): the initialisation method for task embeddings, should be one of the following:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit task embedding from last task.
  • epsilon (float): the value added to the network sparsity to avoid division by zero in equation (9) of the AdaHAT paper.
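For orientation, a minimal construction sketch based on the signature above. The backbone and heads objects are assumed to be already-constructed HATMaskBackbone and HeadsTIL (or HeadsCIL) instances, documented elsewhere, and the hyperparameter values below are purely illustrative, not recommended defaults.

    from clarena.cl_algorithms import AdaHAT

    # `backbone` (HATMaskBackbone) and `heads` (HeadsTIL | HeadsCIL) are assumed
    # to be constructed elsewhere; the numeric values are illustrative only.
    adahat = AdaHAT(
        backbone=backbone,
        heads=heads,
        adjustment_mode="adahat",          # original AdaHAT adjustment, equations (8)-(9)
        adjustment_intensity=0.1,          # alpha in equation (9)
        s_max=400.0,                       # maximum gate scaling factor (HAT)
        clamp_threshold=50.0,              # embedding gradient compensation threshold (HAT)
        mask_sparsity_reg_factor=0.75,     # mask sparsity regularisation factor
        mask_sparsity_reg_mode="original",
        task_embedding_init_mode="N01",
        epsilon=0.1,                       # added to network sparsity in equation (9)
    )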
adjustment_intensity

Store the adjustment intensity, the $\alpha$ in equation (9) of the AdaHAT paper.

epsilon

Store the small value added to avoid division by zero in equation (9) of the AdaHAT paper.

summative_mask_for_previous_tasks: dict[str, torch.Tensor]

Store the summative mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1, \cdots, t-1$, accumulated from the binary attention masks gated from the task embeddings. Keys are layer names and values are the summative mask tensor for the layer. Each mask tensor has size (number of units).
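For illustration only, a hypothetical value of this attribute after two tasks, for an imagined backbone whose weighted layers are named "fc1" (4 units) and "fc2" (3 units); each entry counts how many of the previous tasks activated that unit.

    import torch

    summative_mask_for_previous_tasks = {
        "fc1": torch.tensor([2.0, 1.0, 0.0, 2.0]),  # per-unit counts over tasks 1, ..., t-1
        "fc2": torch.tensor([1.0, 0.0, 2.0]),
    }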

automatic_optimization: bool

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

def sanity_check(self) -> None:

Check the sanity of the arguments.

Raises:

  • ValueError: If the adjustment_intensity is not positive.

def on_train_start(self) -> None:

Additionally initialise the summative mask at the beginning of the first task.

def clip_grad_by_adjustment( self, network_sparsity: dict[str, torch.Tensor] | None = None) -> torch.Tensor:

Clip the gradients by the adjustment rate.

Note that because the task embeddings cover every layer in the backbone network, no parameters are left out of this mechanism. This applies not only to the parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 of the AdaHAT paper.

Args:

  • network_sparsity (dict[str, Tensor] | None): the network sparsity, i.e. the mask sparsity loss, of each layer for the current task. It applies only to the AdaHAT adjustment modes, where it is used to calculate the adjustment rate for the gradients.

Returns:

  • capacity (Tensor): the calculated network capacity.
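A minimal, self-contained sketch of the per-layer rate this method computes in the 'adahat' mode, mirroring the source above; the function name and arguments are introduced here for illustration and are not part of the package API.

    import torch

    def adahat_adjustment_rate(
        parameter_importance: torch.Tensor,    # parameter-wise min of the summative mask
        network_sparsity_layer: torch.Tensor,  # mask sparsity value of this layer
        adjustment_intensity: float = 0.1,     # alpha in equation (9)
        epsilon: float = 0.1,                  # avoids division by zero
    ) -> torch.Tensor:
        # r_l = alpha / (epsilon + R_l), equation (9) of the AdaHAT paper
        r_layer = adjustment_intensity / (epsilon + network_sparsity_layer)
        # rho = r_l / (importance + r_l): close to 1 for parameters rarely used by
        # previous tasks, close to 0 for heavily reused ones
        return r_layer / (parameter_importance + r_layer)

The gradients of the layer's weights (and biases, if any) are then multiplied element-wise by this rate, and the average rate over all parameters is reported as the network capacity.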
def on_train_end(self) -> None:

Additionally update the summative mask after training the task.
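A sketch of the update this hook performs, assuming summative_mask and mask_t are dicts mapping layer names to unit-wise tensors as in the source above: the binary mask of the task just trained is accumulated into the summative mask, layer by layer.

    # M^{<t+1, sum} = M^{<t, sum} + M^t, applied per layer
    summative_mask = {
        layer_name: summative_mask[layer_name] + mask_t[layer_name]
        for layer_name in summative_mask
    }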