clarena.cl_algorithms.adahat

The submodule in cl_algorithms for the AdaHAT (Adaptive Hard Attention to the Task) algorithm.

r"""
The submodule in `cl_algorithms` for the [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) algorithm.
"""

__all__ = ["AdaHAT"]

import logging
from typing import Any

import torch
from torch import Tensor

from clarena.backbones import HATMaskBackbone
from clarena.cl_algorithms import HAT
from clarena.heads import HeadsTIL
from clarena.heads.head_dil import HeadDIL
from clarena.utils.metrics import HATNetworkCapacityMetric

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)

class AdaHAT(HAT):
    r"""[AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) algorithm.

    An architecture-based continual learning approach that improves [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.

    We implement AdaHAT as a subclass of HAT, as it shares the same `forward()`, `compensate_task_embedding_gradients()`, `training_step()`, `validation_step()`, and `test_step()` methods as the `HAT` class, while overriding `on_train_start()`, `clip_grad_by_adjustment()`, and `on_train_end()`.
    """

    def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL | HeadDIL,
        adjustment_mode: str,
        adjustment_intensity: float,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        task_embedding_init_mode: str = "N01",
        epsilon: float = 0.1,
        non_algorithmic_hparams: dict[str, Any] = {},
        **kwargs,
    ) -> None:
        r"""Initialize the AdaHAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
            1. 'adahat': scale the gradients of parameters linked to masked units by a soft adjustment rate, as in the original AdaHAT approach (allowing slight updates to previous-task parameters). See Eqs. (8) and (9) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            2. 'adahat_no_sum': same as 'adahat', but without parameter importance (i.e., no summative mask). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            3. 'adahat_no_reg': same as 'adahat', but without network sparsity (i.e., no mask sparsity regularization term). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **adjustment_intensity** (`float`): hyperparameter controlling the overall intensity of gradient adjustment (the $\alpha$ in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'cross': the cross version of mask sparsity regularization.
        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit the task embedding from the last task.
        - **epsilon** (`float`): the value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
        - **non_algorithmic_hparams** (`dict[str, Any]`): hyperparameters unrelated to the algorithm itself, such as optimizer and learning rate scheduler configurations, passed to this `LightningModule` object from the config. They are saved for Lightning APIs via the `save_hyperparameters()` method, which aids experiment configuration and reproducibility.
        - **kwargs**: reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            adjustment_mode=adjustment_mode,
            s_max=s_max,
            clamp_threshold=clamp_threshold,
            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
            task_embedding_init_mode=task_embedding_init_mode,
            alpha=None,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        self.adjustment_intensity: float = adjustment_intensity
        r"""The adjustment intensity in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)."""
        self.epsilon: float = epsilon
        r"""The small value to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9))."""

        # save additional algorithmic hyperparameters
        self.save_hyperparameters("adjustment_intensity", "epsilon")

        self.summative_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""The summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, )."""

        # set manual optimization
        self.automatic_optimization = False

        AdaHAT.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""
        if self.adjustment_intensity <= 0:
            raise ValueError(
                f"The adjustment intensity should be positive, but got {self.adjustment_intensity}."
            )

    def on_train_start(self) -> None:
        r"""Additionally initialize the summative mask at the beginning of the first task."""
        super().on_train_start()

        # initialize the summative mask at the beginning of the first task. This cannot be done in the `__init__()` method because `self.device` is not available at that time
        if self.summative_mask_for_previous_tasks == {}:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(layer_name)
                num_units = layer.weight.shape[0]

                # the summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialized as a zero mask for $t = 1$. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
                self.summative_mask_for_previous_tasks[layer_name] = torch.zeros(
                    num_units, device=self.device
                )

    def clip_grad_by_adjustment(
        self,
        network_sparsity: dict[str, Tensor] | None = None,
    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
        r"""Clip the gradients by the adjustment rate. See Eq. (8) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        Because the task embeddings cover every layer of the backbone network, no parameter is left out of this adjustment: it applies not only to parameters between layers with task embeddings, but also to those before the first such layer, which are handled separately in the code.

        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        **Args:**
        - **network_sparsity** (`dict[str, Tensor]` | `None`): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. Applies only to modes `adahat` and `adahat_no_sum`, not `adahat_no_reg`.

        **Returns:**
        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **capacity** (`Tensor`): the calculated network capacity.
        """

        # initialize network capacity metric
        capacity = HATNetworkCapacityMetric().to(self.device)
        adjustment_rate_weight = {}
        adjustment_rate_bias = {}

        # calculate the adjustment rate for the gradients of the parameters, both weights and biases (if they exist). See Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(layer_name)

            # placeholders for the adjustment rates to avoid using them before assignment
            adjustment_rate_weight_layer = 1
            adjustment_rate_bias_layer = 1

            # AdaHAT depends on parameter importance rather than parameter masks (as in HAT)
            weight_importance, bias_importance = (
                self.backbone.get_layer_measure_parameter_wise(
                    neuron_wise_measure=self.summative_mask_for_previous_tasks,
                    layer_name=layer_name,
                    aggregation_mode="min",
                )
            )

            if self.adjustment_mode == "adahat":
                network_sparsity_layer = network_sparsity[layer_name]
                r_layer = self.adjustment_intensity / (
                    self.epsilon + network_sparsity_layer
                )
                adjustment_rate_weight_layer = torch.div(
                    r_layer, (weight_importance + r_layer)
                )
                adjustment_rate_bias_layer = torch.div(
                    r_layer, (bias_importance + r_layer)
                )

            elif self.adjustment_mode == "adahat_no_sum":
                # replace parameter importance with the task count (no summative mask)
                network_sparsity_layer = network_sparsity[layer_name]
                r_layer = self.adjustment_intensity / (
                    self.epsilon + network_sparsity_layer
                )
                adjustment_rate_weight_layer = torch.div(
                    r_layer, (self.task_id + r_layer)
                )
                adjustment_rate_bias_layer = torch.div(
                    r_layer, (self.task_id + r_layer)
                )

            elif self.adjustment_mode == "adahat_no_reg":
                # drop the network sparsity term, so `network_sparsity` is not accessed
                r_layer = self.adjustment_intensity / self.epsilon
                adjustment_rate_weight_layer = torch.div(
                    r_layer, (weight_importance + r_layer)
                )
                adjustment_rate_bias_layer = torch.div(
                    r_layer, (bias_importance + r_layer)
                )

            # apply the adjustment rate to the gradients
            layer.weight.grad.data *= adjustment_rate_weight_layer
            if layer.bias is not None:
                layer.bias.grad.data *= adjustment_rate_bias_layer

            # store the adjustment rates for logging
            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
            if layer.bias is not None:
                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer

            # update the network capacity metric
            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)

        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()

    def on_train_end(self) -> None:
        r"""Additionally update the summative mask after training the task."""
        super().on_train_end()

        # get the stored mask for the current task
        mask_t = self.backbone.masks[self.task_id]

        # update the summative mask for previous tasks. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
        self.summative_mask_for_previous_tasks = {
            layer_name: self.summative_mask_for_previous_tasks[layer_name]
            + mask_t[layer_name]
            for layer_name in self.backbone.weighted_layer_names
        }
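
For intuition, the per-layer arithmetic in `clip_grad_by_adjustment()` (the `adahat` mode) and the summative-mask update in `on_train_end()` can be sketched with plain Python scalars and lists. This is a minimal illustration, not part of the library: the helper `adahat_adjustment_rate` and the concrete sparsity, importance, and mask values below are hypothetical stand-ins for the tensors used in the actual code.

```python
def adahat_adjustment_rate(
    adjustment_intensity: float,
    network_sparsity: float,
    parameter_importance: float,
    epsilon: float = 0.1,
) -> float:
    """Mirror Eq. (9): r = alpha / (epsilon + sparsity); rate = r / (importance + r)."""
    r = adjustment_intensity / (epsilon + network_sparsity)
    return r / (parameter_importance + r)


# a parameter never used by previous tasks (importance 0) stays fully plastic
free = adahat_adjustment_rate(1.0, network_sparsity=0.5, parameter_importance=0.0)

# a parameter reused by many previous tasks is nearly frozen
busy = adahat_adjustment_rate(1.0, network_sparsity=0.5, parameter_importance=20.0)

print(free, busy)  # free is exactly 1.0; busy is close to 0

# the summative mask of Eq. (7) is a running elementwise sum of binary task masks
summative_mask = [0, 0, 0]
for task_mask in ([1, 0, 1], [1, 1, 0]):
    summative_mask = [m + t for m, t in zip(summative_mask, task_mask)]

print(summative_mask)  # [2, 1, 1]
```

The monotone shape is the point: the adjustment rate decays from 1 toward 0 as the summed mask (parameter importance) grows, so heavily reused parameters receive vanishing gradient updates while untouched ones learn freely.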
class AdaHAT(clarena.cl_algorithms.hat.HAT):
 24class AdaHAT(HAT):
 25    r"""[AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) algorithm.
 26
 27    An architecture-based continual learning approach that improves [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.
 28
 29    We implement AdaHAT as a subclass of HAT, as it shares the same `forward()`, `compensate_task_embedding_gradients()`, `training_step()`, `on_train_end()`, `validation_step()`, and `test_step()` methods as the `HAT` class.
 30    """
 31
 32    def __init__(
 33        self,
 34        backbone: HATMaskBackbone,
 35        heads: HeadsTIL | HeadDIL,
 36        adjustment_mode: str,
 37        adjustment_intensity: float,
 38        s_max: float,
 39        clamp_threshold: float,
 40        mask_sparsity_reg_factor: float,
 41        mask_sparsity_reg_mode: str = "original",
 42        task_embedding_init_mode: str = "N01",
 43        epsilon: float = 0.1,
 44        non_algorithmic_hparams: dict[str, Any] = {},
 45        **kwargs,
 46    ) -> None:
 47        r"""Initialize the AdaHAT algorithm with the network.
 48
 49        **Args:**
 50        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
 51        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
 52        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
 53            1. 'adahat': set gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach (allows slight updates on previous-task parameters). See Eqs. (8) and (9) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 54            2. 'adahat_no_sum': as above but without parameter-importance (i.e., no summative mask). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 55            3. 'adahat_no_reg': as above but without network sparsity (i.e., no mask sparsity regularization term). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 56        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
 57        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 58        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 59        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 60        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 61            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 62            2. 'cross': the cross version of mask sparsity regularization.
 63        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 64            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 65            2. 'U-11': uniform distribution $U(-1, 1)$.
 66            3. 'U01': uniform distribution $U(0, 1)$.
 67            4. 'U-10': uniform distribution $U(-1, 0)$.
 68            5. 'last': inherit the task embedding from the last task.
 69        - **epsilon** (`float`): the value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
 70        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 71        - **kwargs**: Reserved for multiple inheritance.
 72        """
 73        super().__init__(
 74            backbone=backbone,
 75            heads=heads,
 76            adjustment_mode=adjustment_mode,
 77            s_max=s_max,
 78            clamp_threshold=clamp_threshold,
 79            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
 80            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
 81            task_embedding_init_mode=task_embedding_init_mode,
 82            alpha=None,
 83            non_algorithmic_hparams=non_algorithmic_hparams,
 84            **kwargs,
 85        )
 86
 87        self.adjustment_intensity: float = adjustment_intensity
 88        r"""The adjustment intensity in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)."""
 89        self.epsilon: float | None = epsilon
 90        r"""The small value to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9))."""
 91
 92        # save additional algorithmic hyperparameters
 93        self.save_hyperparameters("adjustment_intensity", "epsilon")
 94
 95        self.summative_mask_for_previous_tasks: dict[str, Tensor] = {}
 96        r"""The summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, )."""
 97
 98        # set manual optimization
 99        self.automatic_optimization = False
100
101        AdaHAT.sanity_check(self)
102
103    def sanity_check(self) -> None:
104        r"""Sanity check."""
105        if self.adjustment_intensity <= 0:
106            raise ValueError(
107                f"The adjustment intensity should be positive, but got {self.adjustment_intensity}."
108            )
109
110    def on_train_start(self) -> None:
111        r"""Additionally initialize the summative mask at the beginning of the first task."""
112        super().on_train_start()
113
114        # initialize the summative mask at the beginning of the first task. This should not be called in `__init__()` method because `self.device` is not available at that time
115        if self.summative_mask_for_previous_tasks == {}:
116            for layer_name in self.backbone.weighted_layer_names:
117                layer = self.backbone.get_layer_by_name(
118                    layer_name
119                )  # get the layer by its name
120                num_units = layer.weight.shape[0]
121
122                self.summative_mask_for_previous_tasks[layer_name] = torch.zeros(
123                    num_units
124                ).to(
125                    self.device
126                )  # the summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialized as a zeros mask for $t = 1$. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
127
128    def clip_grad_by_adjustment(
129        self,
130        network_sparsity: dict[str, Tensor] | None = None,
131    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
132        r"""Clip the gradients by the adjustment rate. See Eq. (8) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
133
134        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
135
136        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
137
138        **Args:**
139        - **network_sparsity** (`dict[str, Tensor]` | `None`): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. Applies only to mode `adahat` and `adahat_no_sum`, not `adahat_no_reg`.
140
141        **Returns:**
142        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
143        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
144        - **capacity** (`Tensor`): the calculated network capacity.
145        """
146
147        # initialize network capacity metric
148        capacity = HATNetworkCapacityMetric().to(self.device)
149        adjustment_rate_weight = {}
150        adjustment_rate_bias = {}
151
152        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
153        for layer_name in self.backbone.weighted_layer_names:
154
155            layer = self.backbone.get_layer_by_name(
156                layer_name
157            )  # get the layer by its name
158
159            # placeholder for the adjustment rate to avoid the error of using it before assignment
160            adjustment_rate_weight_layer = 1
161            adjustment_rate_bias_layer = 1
162
163            weight_importance, bias_importance = (
164                self.backbone.get_layer_measure_parameter_wise(
165                    neuron_wise_measure=self.summative_mask_for_previous_tasks,
166                    layer_name=layer_name,
167                    aggregation_mode="min",
168                )
169            )  # AdaHAT depends on parameter importance rather than parameter masks (as in HAT)
170
171            network_sparsity_layer = network_sparsity[layer_name]
172
173            if self.adjustment_mode == "adahat":
174                r_layer = self.adjustment_intensity / (
175                    self.epsilon + network_sparsity_layer
176                )
177                adjustment_rate_weight_layer = torch.div(
178                    r_layer, (weight_importance + r_layer)
179                )
180                adjustment_rate_bias_layer = torch.div(
181                    r_layer, (bias_importance + r_layer)
182                )
183
184            elif self.adjustment_mode == "adahat_no_sum":
185
186                r_layer = self.adjustment_intensity / (
187                    self.epsilon + network_sparsity_layer
188                )
189                adjustment_rate_weight_layer = torch.div(
190                    r_layer, (self.task_id + r_layer)
191                )
192                adjustment_rate_bias_layer = torch.div(
193                    r_layer, (self.task_id + r_layer)
194                )
195
196            elif self.adjustment_mode == "adahat_no_reg":
197
198                r_layer = self.adjustment_intensity / (self.epsilon + 0.0)
199                adjustment_rate_weight_layer = torch.div(
200                    r_layer, (weight_importance + r_layer)
201                )
202                adjustment_rate_bias_layer = torch.div(
203                    r_layer, (bias_importance + r_layer)
204                )
205
206            # apply the adjustment rate to the gradients
207            layer.weight.grad.data *= adjustment_rate_weight_layer
208            if layer.bias is not None:
209                layer.bias.grad.data *= adjustment_rate_bias_layer
210
211            # store the adjustment rate for logging
212            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
213            if layer.bias is not None:
214                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
215
216            # update network capacity metric
217            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
218
219        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
220
221    def on_train_end(self) -> None:
222        r"""Additionally update the summative mask after training the task."""
223        super().on_train_end()
224
225        mask_t = self.backbone.masks[
226            self.task_id
227        ]  # get stored mask for the current task again
228
229        # update the summative mask for previous tasks. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
230        self.summative_mask_for_previous_tasks = {
231            layer_name: self.summative_mask_for_previous_tasks[layer_name]
232            + mask_t[layer_name]
233            for layer_name in self.backbone.weighted_layer_names
234        }

AdaHAT (Adaptive Hard Attention to the Task) algorithm.

An architecture-based continual learning approach that improves HAT (Hard Attention to the Task) by introducing adaptive soft gradient clipping based on parameter importance and network sparsity.

We implement AdaHAT as a subclass of HAT, as it shares the same forward(), compensate_task_embedding_gradients(), training_step(), on_train_end(), validation_step(), and test_step() methods as the HAT class.

AdaHAT( backbone: clarena.backbones.HATMaskBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadDIL, adjustment_mode: str, adjustment_intensity: float, s_max: float, clamp_threshold: float, mask_sparsity_reg_factor: float, mask_sparsity_reg_mode: str = 'original', task_embedding_init_mode: str = 'N01', epsilon: float = 0.1, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
 32    def __init__(
 33        self,
 34        backbone: HATMaskBackbone,
 35        heads: HeadsTIL | HeadDIL,
 36        adjustment_mode: str,
 37        adjustment_intensity: float,
 38        s_max: float,
 39        clamp_threshold: float,
 40        mask_sparsity_reg_factor: float,
 41        mask_sparsity_reg_mode: str = "original",
 42        task_embedding_init_mode: str = "N01",
 43        epsilon: float = 0.1,
 44        non_algorithmic_hparams: dict[str, Any] = {},
 45        **kwargs,
 46    ) -> None:
 47        r"""Initialize the AdaHAT algorithm with the network.
 48
 49        **Args:**
 50        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
 51        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
 52        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
 53            1. 'adahat': set gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach (allows slight updates on previous-task parameters). See Eqs. (8) and (9) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 54            2. 'adahat_no_sum': as above but without parameter-importance (i.e., no summative mask). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 55            3. 'adahat_no_reg': as above but without network sparsity (i.e., no mask sparsity regularization term). See Sec. 4.3 (ablation study) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 56        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
 57        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 58        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 59        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 60        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 61            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 62            2. 'cross': the cross version of mask sparsity regularization.
 63        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 64            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 65            2. 'U-11': uniform distribution $U(-1, 1)$.
 66            3. 'U01': uniform distribution $U(0, 1)$.
 67            4. 'U-10': uniform distribution $U(-1, 0)$.
 68            5. 'last': inherit the task embedding from the last task.
 69        - **epsilon** (`float`): the value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
 70        - **non_algorithmic_hparams** (`dict[str, Any]`): hyperparameters unrelated to the algorithm itself, passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
 71        - **kwargs**: Reserved for multiple inheritance.
 72        """
 73        super().__init__(
 74            backbone=backbone,
 75            heads=heads,
 76            adjustment_mode=adjustment_mode,
 77            s_max=s_max,
 78            clamp_threshold=clamp_threshold,
 79            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
 80            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
 81            task_embedding_init_mode=task_embedding_init_mode,
 82            alpha=None,
 83            non_algorithmic_hparams=non_algorithmic_hparams,
 84            **kwargs,
 85        )
 86
 87        self.adjustment_intensity: float = adjustment_intensity
 88        r"""The adjustment intensity in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)."""
 89        self.epsilon: float | None = epsilon
 90        r"""The small value to avoid division by zero (appearing in Eq. (9) of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9))."""
 91
 92        # save additional algorithmic hyperparameters
 93        self.save_hyperparameters("adjustment_intensity", "epsilon")
 94
 95        self.summative_mask_for_previous_tasks: dict[str, Tensor] = {}
 96        r"""The summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, )."""
 97
 98        # set manual optimization
 99        self.automatic_optimization = False
100
101        AdaHAT.sanity_check(self)

Initialize the AdaHAT algorithm with the network.

Args:

  • backbone (HATMaskBackbone): must be a backbone network with the HAT mask mechanism.
  • heads (HeadsTIL | HeadDIL): output heads. AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
  • adjustment_mode (str): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
    1. 'adahat': set gradients of parameters linking to masked units to a soft adjustment rate in the original AdaHAT approach (allows slight updates on previous-task parameters). See Eqs. (8) and (9) in Sec. 3.1 of the AdaHAT paper.
    2. 'adahat_no_sum': as above, but without parameter importance (i.e., no summative mask). See Sec. 4.3 (ablation study) in the AdaHAT paper.
    3. 'adahat_no_reg': as above but without network sparsity (i.e., no mask sparsity regularization term). See Sec. 4.3 (ablation study) in the AdaHAT paper.
  • adjustment_intensity (float): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in Eq. (9) of the AdaHAT paper).
  • s_max (float): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
  • clamp_threshold (float): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
  • mask_sparsity_reg_factor (float): hyperparameter, the regularization factor for mask sparsity.
  • mask_sparsity_reg_mode (str): the mode of mask sparsity regularization, must be one of:
    1. 'original' (default): the original mask sparsity regularization in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularization.
  • task_embedding_init_mode (str): the initialization mode for task embeddings, must be one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
  • epsilon (float): the value added to network sparsity to avoid division by zero (appearing in Eq. (9) of the AdaHAT paper).
  • non_algorithmic_hparams (dict[str, Any]): hyperparameters unrelated to the algorithm itself, passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the save_hyperparameters() method, which is useful for experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
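To make the roles of `adjustment_intensity` ($\alpha$) and `epsilon` concrete, here is a minimal arithmetic sketch of the layer-wise rate in Eq. (9) of the AdaHAT paper. The sparsity values $R_l$ below are made up for illustration; this is not library code.

```python
# Sketch of the layer-wise rate r_l = alpha / (epsilon + R_l) from Eq. (9).
# The network-sparsity values R_l are made-up examples, not real measurements.
alpha = 1e-3    # adjustment_intensity
epsilon = 0.1   # guards against division by zero when R_l is near 0

for R_l in (0.0, 0.5, 1.0):
    r_l = alpha / (epsilon + R_l)
    print(f"network sparsity R_l={R_l:.1f} -> r_l={r_l:.5f}")

# A larger R_l gives a smaller r_l, and a smaller r_l shrinks the per-parameter
# adjustment rate r_l / (importance + r_l), i.e., stronger gradient suppression.
```

This is purely the formula's behavior: the rate is monotonically decreasing in both the sparsity value and the parameter importance.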
adjustment_intensity: float

The adjustment intensity in Eq. (9) of the AdaHAT paper.

epsilon: float | None

The small value to avoid division by zero (appearing in Eq. (9) of the AdaHAT paper).

summative_mask_for_previous_tasks: dict[str, torch.Tensor]

The summative binary attention mask $\mathrm{M}^{<t,\text{sum}}$ of previous tasks $1, \cdots, t-1$, gated from the task embedding. It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ).

automatic_optimization: bool
290    @property
291    def automatic_optimization(self) -> bool:
292        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
293        return self._automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

def sanity_check(self) -> None:
103    def sanity_check(self) -> None:
104        r"""Sanity check."""
105        if self.adjustment_intensity <= 0:
106            raise ValueError(
107                f"The adjustment intensity should be positive, but got {self.adjustment_intensity}."
108            )

Sanity check.

def on_train_start(self) -> None:
110    def on_train_start(self) -> None:
111        r"""Additionally initialize the summative mask at the beginning of the first task."""
112        super().on_train_start()
113
114        # initialize the summative mask at the beginning of the first task. This should not be called in `__init__()` method because `self.device` is not available at that time
115        if self.summative_mask_for_previous_tasks == {}:
116            for layer_name in self.backbone.weighted_layer_names:
117                layer = self.backbone.get_layer_by_name(
118                    layer_name
119                )  # get the layer by its name
120                num_units = layer.weight.shape[0]
121
122                self.summative_mask_for_previous_tasks[layer_name] = torch.zeros(
123                    num_units
124                ).to(
125                    self.device
126                )  # the summative mask $\mathrm{M}^{<t,\text{sum}}$ is initialized as a zeros mask for $t = 1$. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)

Additionally initialize the summative mask at the beginning of the first task.
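The zero initialization can be sketched as follows. The layer names and unit counts are hypothetical stand-ins for `backbone.weighted_layer_names` and the layers' output sizes; this is an illustration, not the library implementation.

```python
import torch

# Hypothetical stand-ins for backbone.weighted_layer_names and layer sizes
layer_units = {"fc1": 4, "fc2": 3}

# M^{<1,sum} starts as all zeros: before task 1, no unit is used by any task.
# See Eq. (7) in Sec. 3.1 of the AdaHAT paper.
summative_mask = {name: torch.zeros(n) for name, n in layer_units.items()}

print(summative_mask["fc1"])  # tensor([0., 0., 0., 0.])
```

In the real method this runs in `on_train_start()` rather than `__init__()` so that `self.device` is available for the `.to(self.device)` call.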

def clip_grad_by_adjustment( self, network_sparsity: dict[str, torch.Tensor] | None = None) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], torch.Tensor]:
128    def clip_grad_by_adjustment(
129        self,
130        network_sparsity: dict[str, Tensor] | None = None,
131    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
132        r"""Clip the gradients by the adjustment rate. See Eq. (8) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
133
134        Note that the task embeddings fully cover every layer in the backbone network, so no parameters are left out of this mechanism. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.
135
136        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
137
138        **Args:**
139        - **network_sparsity** (`dict[str, Tensor]` | `None`): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. Applies only to the modes `adahat` and `adahat_no_sum`, not `adahat_no_reg`.
140
141        **Returns:**
142        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
143        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
144        - **capacity** (`Tensor`): the calculated network capacity.
145        """
146
147        # initialize network capacity metric
148        capacity = HATNetworkCapacityMetric().to(self.device)
149        adjustment_rate_weight = {}
150        adjustment_rate_bias = {}
151
152        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (9) in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
153        for layer_name in self.backbone.weighted_layer_names:
154
155            layer = self.backbone.get_layer_by_name(
156                layer_name
157            )  # get the layer by its name
158
159            # placeholder for the adjustment rate to avoid the error of using it before assignment
160            adjustment_rate_weight_layer = 1
161            adjustment_rate_bias_layer = 1
162
163            weight_importance, bias_importance = (
164                self.backbone.get_layer_measure_parameter_wise(
165                    neuron_wise_measure=self.summative_mask_for_previous_tasks,
166                    layer_name=layer_name,
167                    aggregation_mode="min",
168                )
169            )  # AdaHAT depends on parameter importance rather than parameter masks (as in HAT)
170
171            network_sparsity_layer = network_sparsity[layer_name]
172
173            if self.adjustment_mode == "adahat":
174                r_layer = self.adjustment_intensity / (
175                    self.epsilon + network_sparsity_layer
176                )
177                adjustment_rate_weight_layer = torch.div(
178                    r_layer, (weight_importance + r_layer)
179                )
180                adjustment_rate_bias_layer = torch.div(
181                    r_layer, (bias_importance + r_layer)
182                )
183
184            elif self.adjustment_mode == "adahat_no_sum":
185
186                r_layer = self.adjustment_intensity / (
187                    self.epsilon + network_sparsity_layer
188                )
189                adjustment_rate_weight_layer = torch.div(
190                    r_layer, (self.task_id + r_layer)
191                )
192                adjustment_rate_bias_layer = torch.div(
193                    r_layer, (self.task_id + r_layer)
194                )
195
196            elif self.adjustment_mode == "adahat_no_reg":
197
198                r_layer = self.adjustment_intensity / (self.epsilon + 0.0)
199                adjustment_rate_weight_layer = torch.div(
200                    r_layer, (weight_importance + r_layer)
201                )
202                adjustment_rate_bias_layer = torch.div(
203                    r_layer, (bias_importance + r_layer)
204                )
205
206            # apply the adjustment rate to the gradients
207            layer.weight.grad.data *= adjustment_rate_weight_layer
208            if layer.bias is not None:
209                layer.bias.grad.data *= adjustment_rate_bias_layer
210
211            # store the adjustment rate for logging
212            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
213            if layer.bias is not None:
214                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
215
216            # update network capacity metric
217            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
218
219        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()

Clip the gradients by the adjustment rate. See Eq. (8) in the AdaHAT paper.

Note that the task embeddings fully cover every layer in the backbone network, so no parameters are left out of this mechanism. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.

Args:

  • network_sparsity (dict[str, Tensor] | None): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. Applies only to the modes adahat and adahat_no_sum, not adahat_no_reg.

Returns:

  • adjustment_rate_weight (dict[str, Tensor]): the adjustment rate for weights. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
  • adjustment_rate_bias (dict[str, Tensor]): the adjustment rate for biases. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
  • capacity (Tensor): the calculated network capacity.
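As an illustration of the 'adahat' branch, here is a self-contained sketch of the per-layer computation with hypothetical tensors. In the real method, `weight_importance` comes from the summative mask via `get_layer_measure_parameter_wise` and capacity is accumulated with `HATNetworkCapacityMetric`; here capacity is approximated as a plain mean over one layer.

```python
import torch

# Hypothetical single-layer inputs for the 'adahat' mode (Eqs. (8)-(9))
alpha, epsilon = 1e-3, 0.1
network_sparsity_layer = torch.tensor(0.4)                  # made-up sparsity loss
weight_importance = torch.tensor([[0.0, 2.0], [1.0, 3.0]])  # made-up importance
grad = torch.ones(2, 2)                                     # stand-in weight gradient

r_layer = alpha / (epsilon + network_sparsity_layer)        # layer-wise rate, Eq. (9)
adjustment_rate = r_layer / (weight_importance + r_layer)   # per-parameter rate
clipped_grad = grad * adjustment_rate                       # soft clipping, Eq. (8)

# Network capacity is the average adjustment rate over parameters (Sec. 4.1)
capacity = adjustment_rate.mean()
print(clipped_grad)
print(capacity)
```

A parameter with zero importance keeps its full gradient (rate 1), while heavily used parameters are suppressed toward zero but never fully frozen, which is the key difference from HAT's hard clipping.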
def on_train_end(self) -> None:
221    def on_train_end(self) -> None:
222        r"""Additionally update the summative mask after training the task."""
223        super().on_train_end()
224
225        mask_t = self.backbone.masks[
226            self.task_id
227        ]  # get stored mask for the current task again
228
229        # update the summative mask for previous tasks. See Eq. (7) in Sec. 3.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)
230        self.summative_mask_for_previous_tasks = {
231            layer_name: self.summative_mask_for_previous_tasks[layer_name]
232            + mask_t[layer_name]
233            for layer_name in self.backbone.weighted_layer_names
234        }

Additionally update the summative mask after training the task.
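The update can be sketched in isolation. The layer name and mask values below are hypothetical; the real method iterates over `backbone.weighted_layer_names` using the stored mask of the just-finished task.

```python
import torch

# Eq. (7): M^{<t+1,sum} = M^{<t,sum} + M^t, applied layer by layer.
# The layer name and mask values below are hypothetical.
summative_mask = {"fc1": torch.tensor([0.0, 1.0, 2.0])}  # M^{<t,sum}
mask_t = {"fc1": torch.tensor([1.0, 0.0, 1.0])}          # binary mask of task t

summative_mask = {
    layer_name: summative_mask[layer_name] + mask_t[layer_name]
    for layer_name in summative_mask
}

# Each entry now counts how many of tasks 1..t selected that unit.
print(summative_mask["fc1"])  # tensor([1., 1., 3.])
```

This running count is exactly the parameter-importance signal that `clip_grad_by_adjustment()` consumes for the next task.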