clarena.cl_algorithms.hat
The submodule in `cl_algorithms` for the HAT (Hard Attention to the Task) algorithm.
```python
r"""
The submodule in `cl_algorithms` for the [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a).
"""

__all__ = ["HAT"]

import logging
from typing import Any

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from clarena.backbones import HATMaskBackbone
from clarena.cl_algorithms import CLAlgorithm
from clarena.cl_algorithms.regularisers import HATMaskSparsityReg
from clarena.cl_heads import HeadsCIL, HeadsTIL
from clarena.utils import HATNetworkCapacity

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class HAT(CLAlgorithm):
    r"""HAT (Hard Attention to the Task) algorithm.

    [HAT (Hard Attention to the Task, 2018)](http://proceedings.mlr.press/v80/serra18a) is an architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
    """

    def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL | HeadsCIL,
        adjustment_mode: str,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        task_embedding_init_mode: str = "N01",
        alpha: float | None = None,
    ) -> None:
        r"""Initialise the HAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        - **adjustment_mode** (`str`): the strategy of adjustment, i.e. the mode of gradient clipping. Should be one of the following:
            1. 'hat': set the gradients of parameters linking to masked units to zero. This is what HAT does, completely fixing the part of the network used by previous tasks. See equation (2) in chapter 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'hat_random': set the gradients of parameters linking to masked units to random 0-1 values. See the "Baselines" section in chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            3. 'hat_const_alpha': set the gradients of parameters linking to masked units to a constant value of `alpha`. See the "Baselines" section in chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            4. 'hat_const_1': set the gradients of parameters linking to masked units to a constant value of 1, which means no gradient constraint on any parameter at all. See the "Baselines" section in chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See chapter 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See chapter 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularisation factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularisation. Should be one of the following:
            1. 'original' (default): the original mask sparsity regularisation in the HAT paper.
            2. 'cross': the cross version of mask sparsity regularisation.
        - **task_embedding_init_mode** (`str`): the initialisation mode for task embeddings. Should be one of the following:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit the task embedding from the last task.
        - **alpha** (`float` | `None`): the `alpha` in the 'hat_const_alpha' mode. See the "Baselines" section in chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). It applies only when `adjustment_mode` is 'hat_const_alpha'.
        """
        CLAlgorithm.__init__(self, backbone=backbone, heads=heads)

        self.adjustment_mode = adjustment_mode
        r"""Store the adjustment mode for gradient clipping."""
        self.s_max = s_max
        r"""Store s_max."""
        self.clamp_threshold = clamp_threshold
        r"""Store the clamp threshold for task embedding gradient compensation."""
        self.mask_sparsity_reg_factor = mask_sparsity_reg_factor
        r"""Store the mask sparsity regularisation factor."""
        self.mask_sparsity_reg_mode = mask_sparsity_reg_mode
        r"""Store the mask sparsity regularisation mode."""
        self.mark_sparsity_reg = HATMaskSparsityReg(
            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
        )
        r"""Initialise and store the mask sparsity regulariser."""
        self.task_embedding_init_mode = task_embedding_init_mode
        r"""Store the task embedding initialisation mode."""
        self.alpha = alpha if adjustment_mode == "hat_const_alpha" else None
        r"""Store the alpha for `hat_const_alpha`."""
        self.epsilon = None
        r"""HAT doesn't use the epsilon for `hat_const_alpha`. We still set it here to be consistent with the `epsilon` in the `clip_grad_by_adjustment()` method in `HATMaskBackbone`."""

        self.masks: dict[str, dict[str, Tensor]] = {}
        r"""Store the binary attention mask of each previous task, gated from the task embedding. Keys are task IDs (string type) and values are the corresponding masks. Each mask is a dict where keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has size (number of units)."""

        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""Store the cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1, \cdots, t-1$, gated from the task embeddings. Keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has size (number of units)."""

        # set manual optimisation
        self.automatic_optimization = False

        HAT.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Check the sanity of the arguments.

        **Raises:**
        - **ValueError**: when the backbone is not designed for HAT, or the `mask_sparsity_reg_mode` or `task_embedding_init_mode` is not one of the valid options. Also raised if `alpha` is not given when `adjustment_mode` is 'hat_const_alpha'.
        """
        if not isinstance(self.backbone, HATMaskBackbone):
            raise ValueError("The backbone should be an instance of HATMaskBackbone.")

        if self.mask_sparsity_reg_mode not in ["original", "cross"]:
            raise ValueError(
                "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
            )
        if self.task_embedding_init_mode not in [
            "N01",
            "U01",
            "U-10",
            "masked",
            "unmasked",
        ]:
            raise ValueError(
                "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'."
            )

        if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
            raise ValueError(
                "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
            )

    def on_train_start(self) -> None:
        r"""Initialise the task embedding before training the next task and initialise the cumulative mask at the beginning of the first task."""

        self.backbone.initialise_task_embedding(mode=self.task_embedding_init_mode)

        # initialise the cumulative mask at the beginning of the first task. This cannot be done in `__init__()` because `self.device` is not available at that time.
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name
                num_units = layer.weight.shape[0]

                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
                    num_units
                ).to(
                    self.device
                )  # the cumulative mask $\mathrm{M}^{<t}$ is initialised as a zero mask ($t = 1$). See equation (2) in chapter 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or equation (5) in chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

    def clip_grad_by_adjustment(
        self,
        **kwargs,
    ) -> Tensor:
        r"""Clip the gradients by the adjustment rate.

        Note that because the task embeddings fully cover every weighted layer in the backbone network, no parameters are left out of this system. This applies not only to the parameters between layers with task embeddings, but also to those before the first layer. We handle them separately in the code.

        Network capacity is measured along with this method. Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        **Returns:**
        - **capacity** (`Tensor`): the calculated network capacity.
        """

        # initialise the network capacity metric
        capacity = HATNetworkCapacity()

        # calculate the adjustment rate for the gradients of the parameters, both weights and biases (if they exist)
        for layer_name in self.backbone.weighted_layer_names:

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # placeholders for the adjustment rates to avoid using them before assignment
            adjustment_rate_weight = 1
            adjustment_rate_bias = 1

            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
                unit_wise_measure=self.cumulative_mask_for_previous_tasks,
                layer_name=layer_name,
                aggregation="min",
            )

            if self.adjustment_mode == "hat":
                adjustment_rate_weight = 1 - weight_mask
                adjustment_rate_bias = 1 - bias_mask

            elif self.adjustment_mode == "hat_random":
                adjustment_rate_weight = torch.rand_like(weight_mask) * weight_mask + (
                    1 - weight_mask
                )
                adjustment_rate_bias = torch.rand_like(bias_mask) * bias_mask + (
                    1 - bias_mask
                )

            elif self.adjustment_mode == "hat_const_alpha":
                adjustment_rate_weight = self.alpha * torch.ones_like(
                    weight_mask
                ) * weight_mask + (1 - weight_mask)
                adjustment_rate_bias = self.alpha * torch.ones_like(
                    bias_mask
                ) * bias_mask + (1 - bias_mask)

            elif self.adjustment_mode == "hat_const_1":
                adjustment_rate_weight = torch.ones_like(weight_mask) * weight_mask + (
                    1 - weight_mask
                )
                adjustment_rate_bias = torch.ones_like(bias_mask) * bias_mask + (
                    1 - bias_mask
                )

            # apply the adjustment rate to the gradients
            layer.weight.grad.data *= adjustment_rate_weight
            if layer.bias is not None:
                layer.bias.grad.data *= adjustment_rate_bias

            # update the network capacity metric
            capacity.update(adjustment_rate_weight, adjustment_rate_bias)

        return capacity.compute()

    def compensate_task_embedding_gradients(
        self,
        batch_idx: int,
        num_batches: int,
    ) -> None:
        r"""Compensate the gradients of the task embeddings during training. See chapter 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Args:**
        - **batch_idx** (`int`): the current training batch index.
        - **num_batches** (`int`): the total number of training batches.
        """

        for te in self.backbone.task_embedding_t.values():
            anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * (
                batch_idx - 1
            ) / (
                num_batches - 1
            )  # see equation (3) in chapter 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)

            num = (
                torch.cosh(
                    torch.clamp(
                        anneal_scalar * te.weight.data,
                        -self.clamp_threshold,
                        self.clamp_threshold,
                    )
                )
                + 1
            )

            den = torch.cosh(te.weight.data) + 1

            compensation = self.s_max / anneal_scalar * num / den

            te.weight.grad.data *= compensation

    def forward(
        self,
        input: torch.Tensor,
        stage: str,
        batch_idx: int | None = None,
        num_batches: int | None = None,
        task_id: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. Note that this has nothing to do with the `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): the input tensor from the data.
        - **stage** (`str`): the stage of the forward pass. Should be one of the following:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **batch_idx** (`int` | `None`): the current batch index. Applies only to the training stage. For other stages, it defaults to `None`.
        - **num_batches** (`int` | `None`): the total number of batches. Applies only to the training stage. For other stages, it defaults to `None`.
        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. The HAT algorithm works only for TIL.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
        - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm does not need this, it is still provided for API consistency with other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
        """
        feature, mask, hidden_features = self.backbone(
            input,
            stage=stage,
            s_max=self.s_max if stage == "train" or stage == "validation" else None,
            batch_idx=batch_idx if stage == "train" else None,
            num_batches=num_batches if stage == "train" else None,
            test_mask=self.masks[f"{task_id}"] if stage == "test" else None,
        )
        logits = self.heads(feature, task_id)

        return logits, mask, hidden_features

    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
        r"""Training step for the current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.
        - **batch_idx** (`int`): the index of the batch. Used for calculating the annealed scalar in HAT. See chapter 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Key (`str`) is the metric name, value (`Tensor`) is the metric. Must include the key 'loss', which is the total loss in the case of automatic optimisation, according to the PyTorch Lightning docs. For HAT, it also includes 'mask' and 'capacity' for logging.
        """
        x, y = batch

        # zero the gradients before the forward pass in manual optimisation mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        num_batches = self.trainer.num_training_batches
        logits, mask, hidden_features = self.forward(
            x,
            stage="train",
            batch_idx=batch_idx,
            num_batches=num_batches,
            task_id=self.task_id,
        )
        loss_cls = self.criterion(logits, y)

        # regularisation loss. See chapter 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        loss_reg, network_sparsity = self.mark_sparsity_reg(
            mask, self.cumulative_mask_for_previous_tasks
        )

        # total loss
        loss = loss_cls + loss_reg

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # HAT hard-clips gradients by the cumulative masks. See equation (2) in chapter 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). Network capacity is calculated along with this process. Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        capacity = self.clip_grad_by_adjustment(
            network_sparsity=network_sparsity,  # pass the keyword argument network_sparsity here to keep this compatible with AdaHAT, which inherits this `training_step()` method
        )
        # compensate the gradients of the task embeddings. See chapter 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        self.compensate_task_embedding_gradients(
            batch_idx=batch_idx,
            num_batches=num_batches,
        )
        # update parameters with the modified gradients
        opt.step()

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "loss_reg": loss_reg,
            "acc": acc,
            "hidden_features": hidden_features,
            "mask": mask,  # return other metrics for the lightning loggers callback to handle at `on_train_batch_end()`
            "capacity": capacity,
        }

    def on_train_end(self) -> None:
        r"""Store the mask and update the cumulative mask after training the task."""

        # store the mask for the current task
        mask_t = {
            layer_name: self.backbone.gate_fn(
                self.backbone.task_embedding_t[layer_name].weight * self.s_max
            )
            .squeeze()
            .detach()
            for layer_name in self.backbone.weighted_layer_names
        }

        self.masks[f"{self.task_id}"] = mask_t

        # update the cumulative mask
        self.cumulative_mask_for_previous_tasks = {
            layer_name: torch.max(
                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
            )
            for layer_name in self.backbone.weighted_layer_names
        }

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for the current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Key (`str`) is the metric name, value (`Tensor`) is the metric.
        """
        x, y = batch
        logits, mask, hidden_features = self.forward(
            x, stage="validation", task_id=self.task_id
        )
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_validation_batch_end()`
        }

    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for the current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **batch_idx** (`int`): the index of the batch.
        - **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Key (`str`) is the metric name, value (`Tensor`) is the metric.
        """
        test_task_id = dataloader_idx + 1

        x, y = batch
        logits, mask, hidden_features = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_test_batch_end()`
        }
```
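For orientation, here is a minimal construction sketch of the class above. The `backbone` and `heads` objects are assumed to be built elsewhere (their constructors live in `clarena.backbones` and `clarena.cl_heads` and are not shown on this page), and the numeric hyperparameter values are illustrative placeholders, not recommendations:

```python
from clarena.backbones import HATMaskBackbone
from clarena.cl_algorithms.hat import HAT
from clarena.cl_heads import HeadsTIL


def build_hat(backbone: HATMaskBackbone, heads: HeadsTIL) -> HAT:
    """Assemble a HAT algorithm instance from a pre-built backbone and heads."""
    return HAT(
        backbone=backbone,              # must implement the HAT mask mechanism
        heads=heads,                    # HAT works only for TIL heads
        adjustment_mode="hat",          # original hard gradient clipping
        s_max=400.0,                    # maximum scaling factor of the gate function
        clamp_threshold=50.0,           # clamp for embedding gradient compensation
        mask_sparsity_reg_factor=0.75,  # weight of the mask sparsity regulariser
        mask_sparsity_reg_mode="original",
        task_embedding_init_mode="N01",
    )
```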
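For reference, the annealing of the scaling factor that the docstrings cite as equation (3) of the HAT paper is exactly what `compensate_task_embedding_gradients()` computes as `anneal_scalar`: with $b$ the 1-based batch index and $B$ the number of training batches,

$$ s = \frac{1}{s_{\max}} + \left( s_{\max} - \frac{1}{s_{\max}} \right) \frac{b - 1}{B - 1}. $$

In the original HAT formulation the hard attention mask of task $t$ at layer $l$ is obtained by gating the task embedding $\mathbf{e}^t_l$ with a sigmoid, $\mathbf{m}^t_l = \sigma\!\left(s\,\mathbf{e}^t_l\right)$; at the end of training, `on_train_end()` stores the near-binary mask obtained with $s = s_{\max}$.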
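The 'original' mask sparsity regularisation mode refers to the regulariser of chapter 2.6 "Promoting Low Capacity Usage" in the HAT paper. As a sketch of what `HATMaskSparsityReg` is expected to compute (the actual implementation lives in `clarena.cl_algorithms.regularisers`), it penalises attention on units not already reserved by previous tasks:

$$ R\!\left(\mathrm{M}^{t}, \mathrm{M}^{<t}\right) = \frac{\sum_{l}\sum_{i} m^{t}_{l,i}\left(1 - m^{<t}_{l,i}\right)}{\sum_{l}\sum_{i}\left(1 - m^{<t}_{l,i}\right)}, $$

scaled by `mask_sparsity_reg_factor` before being added to the classification loss in `training_step()`.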
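Concretely, for `adjustment_mode='hat'` the adjustment rate computed in `clip_grad_by_adjustment()` reproduces equation (2) of the HAT paper: the gradient of a weight connecting unit $j$ of layer $l-1$ to unit $i$ of layer $l$ is scaled by

$$ g'_{l,ij} = \left[ 1 - \min\!\left( m^{<t}_{l,i},\, m^{<t}_{l-1,j} \right) \right] g_{l,ij}, $$

which is what the `aggregation="min"` call followed by `1 - weight_mask` implements. The other modes replace the zero factor on masked parameters with random values, a constant `alpha`, or 1. The network capacity returned by the method is the average of these adjustment rates over all weights and biases, as defined in chapter 4.1 of the AdaHAT paper.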
Because HAT modifies the gradients between the backward pass and the optimiser step, `__init__()` sets the inherited `automatic_optimization` property (from the underlying `LightningModule`) to `False`:

```python
@property
def automatic_optimization(self) -> bool:
    """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
    return self._automatic_optimization
```
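Written out, the compensation applied to the task-embedding gradients in `compensate_task_embedding_gradients()` is, for each embedding entry $e$ with raw gradient $g$ and clamp threshold $c$ (`clamp_threshold`),

$$ g' = \frac{s_{\max}}{s} \cdot \frac{\cosh\!\big(\operatorname{clamp}(s e,\, -c,\, c)\big) + 1}{\cosh(e) + 1}\, g, $$

which follows chapter 2.5 "Embedding Gradient Compensation" of the HAT paper and counteracts the vanishing gradients of the annealed sigmoid gate.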
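After training task $t$, `on_train_end()` stores the task's near-binary mask and folds it into the cumulative mask by an element-wise maximum, per layer:

$$ \mathrm{M}^{<t+1} = \max\!\left( \mathrm{M}^{<t},\, \mathrm{M}^{t} \right), $$

so the cumulative mask used for gradient clipping and the sparsity regulariser always records every unit any previous task has claimed.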
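At test time, `forward()` applies the stored binary mask of the tested task rather than re-gating the task embedding, as the `test_mask=self.masks[f"{task_id}"]` line in the source shows. A hedged usage sketch, assuming a trained `HAT` instance `model` and a batch `x` drawn from task 2:

```python
import torch

# `model` is a trained HAT instance and `x` a batch of inputs from task 2 (assumptions).
# The stored binary mask of that task is applied inside the backbone, and the matching
# TIL head produces the logits.
model.eval()
with torch.no_grad():
    logits, mask, hidden_features = model.forward(x, stage="test", task_id=2)
    predictions = logits.argmax(dim=1)
```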
150 def clip_grad_by_adjustment( 151 self, 152 **kwargs, 153 ) -> Tensor: 154 r"""Clip the gradients by the adjustment rate. 155 156 Note that as the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only the parameters in between layers with task embedding, but also those before the first layer. We designed it seperately in the codes. 157 158 Network capacity is measured along with this method. Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 159 160 161 **Returns:** 162 - **capacity** (`Tensor`): the calculated network capacity. 163 """ 164 165 # initialise network capacity metric 166 capacity = HATNetworkCapacity() 167 168 # Calculate the adjustment rate for gradients of the parameters, both weights and biases (if exists) 169 for layer_name in self.backbone.weighted_layer_names: 170 171 layer = self.backbone.get_layer_by_name( 172 layer_name 173 ) # get the layer by its name 174 175 # placeholder for the adjustment rate to avoid the error of using it before assignment 176 adjustment_rate_weight = 1 177 adjustment_rate_bias = 1 178 179 weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise( 180 unit_wise_measure=self.cumulative_mask_for_previous_tasks, 181 layer_name=layer_name, 182 aggregation="min", 183 ) 184 185 if self.adjustment_mode == "hat": 186 adjustment_rate_weight = 1 - weight_mask 187 adjustment_rate_bias = 1 - bias_mask 188 189 elif self.adjustment_mode == "hat_random": 190 adjustment_rate_weight = torch.rand_like(weight_mask) * weight_mask + ( 191 1 - weight_mask 192 ) 193 adjustment_rate_bias = torch.rand_like(bias_mask) * bias_mask + ( 194 1 - bias_mask 195 ) 196 197 elif self.adjustment_mode == "hat_const_alpha": 198 adjustment_rate_weight = self.alpha * torch.ones_like( 199 weight_mask 200 ) * weight_mask + (1 - weight_mask) 201 adjustment_rate_bias = self.alpha * torch.ones_like( 202 bias_mask 203 ) * bias_mask + (1 - bias_mask) 204 205 elif self.adjustment_mode == "hat_const_1": 206 adjustment_rate_weight = torch.ones_like(weight_mask) * weight_mask + ( 207 1 - weight_mask 208 ) 209 adjustment_rate_bias = torch.ones_like(bias_mask) * bias_mask + ( 210 1 - bias_mask 211 ) 212 213 # apply the adjustment rate to the gradients 214 layer.weight.grad.data *= adjustment_rate_weight 215 if layer.bias is not None: 216 layer.bias.grad.data *= adjustment_rate_bias 217 218 # update network capacity metric 219 capacity.update(adjustment_rate_weight, adjustment_rate_bias) 220 221 return capacity.compute()
Clip the gradients by the adjustment rate.
Note that as the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only the parameters in between layers with task embedding, but also those before the first layer. We designed it seperately in the codes.
Network capacity is measured along with this method. Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in AdaHAT paper.
Returns:
- capacity (
Tensor
): the calculated network capacity.
223 def compensate_task_embedding_gradients( 224 self, 225 batch_idx: int, 226 num_batches: int, 227 ) -> None: 228 r"""Compensate the gradients of task embeddings during training. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a). 229 230 **Args:** 231 - **batch_idx** (`int`): the current training batch index. 232 - **num_batches** (`int`): the total number of training batches. 233 """ 234 235 for te in self.backbone.task_embedding_t.values(): 236 anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * ( 237 batch_idx - 1 238 ) / ( 239 num_batches - 1 240 ) # see equation (3) in chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a) 241 242 num = ( 243 torch.cosh( 244 torch.clamp( 245 anneal_scalar * te.weight.data, 246 -self.clamp_threshold, 247 self.clamp_threshold, 248 ) 249 ) 250 + 1 251 ) 252 253 den = torch.cosh(te.weight.data) + 1 254 255 compensation = self.s_max / anneal_scalar * num / den 256 257 te.weight.grad.data *= compensation
Compensate the gradients of task embeddings during training. See chapter 2.5 "Embedding Gradient Compensation" in HAT paper.
Args:
- batch_idx (
int
): the current training batch index. - num_batches (
int
): the total number of training batches.
259 def forward( 260 self, 261 input: torch.Tensor, 262 stage: str, 263 batch_idx: int | None = None, 264 num_batches: int | None = None, 265 task_id: int | None = None, 266 ) -> tuple[Tensor, dict[str, Tensor]]: 267 r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`. 268 269 **Args:** 270 - **input** (`Tensor`): The input tensor from data. 271 - **stage** (`str`): the stage of the forward pass, should be one of the following: 272 1. 'train': training stage. 273 2. 'validation': validation stage. 274 3. 'test': testing stage. 275 - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`. 276 - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`. 277 - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL. 278 279 **Returns:** 280 - **logits** (`Tensor`): the output logits tensor. 281 - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units). 282 - **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class. 283 """ 284 feature, mask, hidden_features = self.backbone( 285 input, 286 stage=stage, 287 s_max=self.s_max if stage == "train" or stage == "validation" else None, 288 batch_idx=batch_idx if stage == "train" else None, 289 num_batches=num_batches if stage == "train" else None, 290 test_mask=self.masks[f"{task_id}"] if stage == "test" else None, 291 ) 292 logits = self.heads(feature, task_id) 293 294 return logits, mask, hidden_features
The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.

**Args:**
- **input** (`Tensor`): the input tensor from data.
- **stage** (`str`): the stage of the forward pass, should be one of the following:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **batch_idx** (`int` | `None`): the current batch index. Applies only to the training stage; for other stages, it defaults to `None`.
- **num_batches** (`int` | `None`): the total number of batches. Applies only to the training stage; for other stages, it defaults to `None`.
- **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, thus this argument can be used. The HAT algorithm works only for TIL.

**Returns:**
- **logits** (`Tensor`): the output logits tensor.
- **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units).
- **hidden_features** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although HAT itself does not need it, it is provided for API consistency with other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
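A rough usage sketch of the stage-dependent arguments. The names `algo`, `x`, `batch_idx` and `num_batches` are hypothetical here, standing for an already constructed `HAT` instance, a batch of inputs, and its batch bookkeeping; they are not defined by the module itself.

```python
# Training/validation: the mask is gated from the current task embedding,
# scaled by s_max (and annealed per batch during training only).
logits, mask, hidden_features = algo.forward(
    x,
    stage="train",
    batch_idx=batch_idx,
    num_batches=num_batches,
    task_id=algo.task_id,
)

# Test: the stored binary mask of the queried (previously trained) task is applied.
logits, mask, hidden_features = algo.forward(x, stage="test", task_id=1)
```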
    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.
        - **batch_idx** (`int`): the index of the batch. Used for calculating the annealed scalar in HAT. See chapter 2.4 "Hard Attention Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Key (`str`) is the metric name, value (`Tensor`) is the metric. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. For HAT, it also includes 'mask' and 'capacity' for logging.
        """
        x, y = batch

        # zero the gradients before the forward pass in manual optimisation mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        num_batches = self.trainer.num_training_batches
        logits, mask, hidden_features = self.forward(
            x,
            stage="train",
            batch_idx=batch_idx,
            num_batches=num_batches,
            task_id=self.task_id,
        )
        loss_cls = self.criterion(logits, y)

        # regularisation loss. See chapter 2.6 "Promoting Low Capacity Usage" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        loss_reg, network_sparsity = self.mark_sparsity_reg(
            mask, self.cumulative_mask_for_previous_tasks
        )

        # total loss
        loss = loss_cls + loss_reg

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # HAT hard-clips gradients by the cumulative masks. See equation (2) in chapter 2.3 "Network Training" in [HAT paper](http://proceedings.mlr.press/v80/serra18a). Network capacity is calculated along with this process; it is defined as the average adjustment rate over all parameters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        capacity = self.clip_grad_by_adjustment(
            network_sparsity=network_sparsity,  # pass network sparsity as a keyword argument to stay compatible with AdaHAT, which inherits this `training_step()` method
        )
        # compensate the gradients of task embeddings. See chapter 2.5 "Embedding Gradient Compensation" in [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        self.compensate_task_embedding_gradients(
            batch_idx=batch_idx,
            num_batches=num_batches,
        )
        # update parameters with the modified gradients
        opt.step()

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "loss_reg": loss_reg,
            "acc": acc,
            "hidden_features": hidden_features,
            "mask": mask,  # return other metrics for the lightning loggers callback to handle at `on_train_batch_end()`
            "capacity": capacity,
        }
Training step for current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of training data.
- **batch_idx** (`int`): the index of the batch. Used for calculating the annealed scalar in HAT. See chapter 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Key (`str`) is the metric name, value (`Tensor`) is the metric. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. For HAT, it also includes 'mask' and 'capacity' for logging.
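The regularisation term added to the classification loss penalises attention on units that previous tasks have left free, which is what promotes low capacity usage. The library computes it through `HATMaskSparsityReg` (with a regularisation factor and an 'original'/'cross' mode); the function below is only an illustrative re-derivation of the paper's 'original' formulation on toy masks, ignoring the factor, and is not the library's implementation.

```python
import torch
from torch import Tensor


def toy_mask_sparsity_reg(
    mask: dict[str, Tensor], cumulative_mask: dict[str, Tensor]
) -> Tensor:
    """Illustrative HAT mask-sparsity term (chapter 2.6 of the HAT paper):
    the attention placed on units not used by previous tasks, normalised by
    the number of such free units."""
    numerator = sum(
        (m * (1 - cumulative_mask[name])).sum() for name, m in mask.items()
    )
    denominator = sum((1 - cm).sum() for cm in cumulative_mask.values())
    return numerator / denominator


# Toy example, one layer of 4 units: the first two were used by previous tasks.
mask_t = {"fc1": torch.tensor([0.9, 0.1, 0.8, 0.2])}
cumulative = {"fc1": torch.tensor([1.0, 1.0, 0.0, 0.0])}
print(toy_mask_sparsity_reg(mask_t, cumulative))  # (0.8 + 0.2) / 2 = tensor(0.5000)
```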
    def on_train_end(self) -> None:
        r"""Store the mask and update the cumulative mask after training the task."""

        # store the mask for the current task
        mask_t = {
            layer_name: self.backbone.gate_fn(
                self.backbone.task_embedding_t[layer_name].weight * self.s_max
            )
            .squeeze()
            .detach()
            for layer_name in self.backbone.weighted_layer_names
        }

        self.masks[f"{self.task_id}"] = mask_t

        # update the cumulative mask with the current task's mask
        self.cumulative_mask_for_previous_tasks = {
            layer_name: torch.max(
                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
            )
            for layer_name in self.backbone.weighted_layer_names
        }
Store the mask and update the cumulative mask after training the task.
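To see what gets stored: the task mask is obtained by pushing the learned task embedding through the gate at full scale `s_max`, which saturates to (near-)binary values, and the cumulative mask is the element-wise maximum over all stored masks. The toy snippet below assumes a plain sigmoid gate (the actual `gate_fn` lives on `HATMaskBackbone`) and made-up numbers.

```python
import torch

s_max = 400.0
embedding = torch.tensor([3.0, -2.0, 0.5, -0.1])  # a toy layer's task embedding

# Gating at full scale: sigmoid(s_max * e) saturates to near-binary attention values.
mask_t = torch.sigmoid(s_max * embedding)
print(mask_t)  # ~[1., 0., 1., 0.]

# The cumulative mask of previous tasks is the element-wise maximum of stored masks.
cumulative = torch.tensor([0.0, 1.0, 0.0, 1.0])
cumulative = torch.max(cumulative, mask_t)
print(cumulative)  # ~[1., 1., 1., 1.]
```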
    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Key (`str`) is the metric name, value (`Tensor`) is the metric.
        """
        x, y = batch
        logits, mask, hidden_features = self.forward(
            x, stage="validation", task_id=self.task_id
        )
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_validation_batch_end()`
        }
Validation step for current task `self.task_id`.

**Args:**
- **batch** (`Any`): a batch of validation data.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Key (`str`) is the metric name, value (`Tensor`) is the metric.
    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for current task `self.task_id`, which tests all seen tasks through separate dataloaders indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **batch_idx** (`int`): the index of the batch.
        - **dataloader_idx** (`int`): the index of the test dataloader (one per seen task); the tested task ID is `dataloader_idx + 1`. A default value of 0 is given, otherwise the `LightningModule` will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Key (`str`) is the metric name, value (`Tensor`) is the metric.
        """
        test_task_id = dataloader_idx + 1

        x, y = batch
        logits, mask, hidden_features = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask of the tested task (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_test_batch_end()`
        }
Test step for current task `self.task_id`, which tests all seen tasks through separate dataloaders indexed by `dataloader_idx`.

**Args:**
- **batch** (`Any`): a batch of test data.
- **batch_idx** (`int`): the index of the batch.
- **dataloader_idx** (`int`): the index of the test dataloader (one per seen task); the tested task ID is `dataloader_idx + 1`. A default value of 0 is given, otherwise the `LightningModule` will raise a `RuntimeError`.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Key (`str`) is the metric name, value (`Tensor`) is the metric.
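Putting the mapping together: each seen task gets its own test dataloader, the 0-based `dataloader_idx` becomes the 1-based task ID, and the stored binary mask plus the matching head of that task are used. A hedged sketch follows; `algo` and `test_loaders` are hypothetical names standing for a trained `HAT` instance and a list of per-task test dataloaders, and this loop is not part of the class.

```python
# Hypothetical evaluation loop over all seen tasks, outside Lightning's Trainer.
for dataloader_idx, loader in enumerate(test_loaders):
    test_task_id = dataloader_idx + 1  # dataloaders are 0-based, task IDs 1-based
    for x, y in loader:
        # stage="test" makes the backbone apply the stored mask self.masks[f"{test_task_id}"]
        logits, _, _ = algo.forward(x, stage="test", task_id=test_task_id)
        acc = (logits.argmax(dim=1) == y).float().mean()
```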