clarena.cl_algorithms.hat
The submodule in cl_algorithms for HAT (Hard Attention to the Task) algorithm.
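Before the class itself, it helps to see the gate everything below builds on. HAT derives a per-layer attention mask by passing a learnable task embedding through a sigmoid whose slope is annealed from nearly linear to nearly binary over each training epoch (Eq. (3) in Sec. 2.4 of the [HAT paper](http://proceedings.mlr.press/v80/serra18a)). A minimal standalone sketch — the function names and the 1-based batch index are illustrative, not part of this module's API:

```python
import torch

def annealed_scale(batch_idx: int, num_batches: int, s_max: float) -> float:
    # Eq. (3), Sec. 2.4 in the HAT paper: the gate scale anneals linearly
    # from 1/s_max at the first batch to s_max at the last batch of an epoch.
    # batch_idx is 1-based here, as in the paper.
    return 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)

def hard_attention_mask(task_embedding: torch.Tensor, scale: float) -> torch.Tensor:
    # The gate: a sigmoid of the scaled embedding. For scale ~ s_max >> 1,
    # the mask saturates toward binary {0, 1} values.
    return torch.sigmoid(scale * task_embedding)

embedding = torch.tensor([3.0, -3.0, 0.5])
soft = hard_attention_mask(embedding, annealed_scale(1, 100, s_max=400.0))
hard = hard_attention_mask(embedding, annealed_scale(100, 100, s_max=400.0))
```

At the first batch the mask is nearly uniform around 0.5, so gradients flow everywhere; by the last batch it is effectively binary, which is the mask stored for the task.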
1r""" 2The submodule in `cl_algorithms` for [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a). 3""" 4 5__all__ = ["HAT"] 6 7import logging 8from typing import Any 9 10import torch 11from torch import Tensor 12from torch.utils.data import DataLoader 13 14from clarena.backbones import HATMaskBackbone 15from clarena.cl_algorithms import CLAlgorithm 16from clarena.cl_algorithms.regularizers import HATMaskSparsityReg 17from clarena.heads import HeadDIL, HeadsTIL 18from clarena.utils.metrics import HATNetworkCapacityMetric 19 20# always get logger for built-in logging in each module 21pylogger = logging.getLogger(__name__) 22 23 24class HAT(CLAlgorithm): 25 r"""[HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm. 26 27 An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters. 28 """ 29 30 def __init__( 31 self, 32 backbone: HATMaskBackbone, 33 heads: HeadsTIL | HeadDIL, 34 adjustment_mode: str, 35 s_max: float, 36 clamp_threshold: float, 37 mask_sparsity_reg_factor: float, 38 mask_sparsity_reg_mode: str = "original", 39 task_embedding_init_mode: str = "N01", 40 alpha: float | None = None, 41 non_algorithmic_hparams: dict[str, Any] = {}, 42 **kwargs, 43 ) -> None: 44 r"""Initialize the HAT algorithm with the network. 45 46 **Args:** 47 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 48 - **heads** (`HeadsTIL` | `HeadDIL`): output heads. HAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning). 49 - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of: 50 1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 
51 2. 'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 52 3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 53 4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 54 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 55 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 56 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity. 57 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 58 1. 'original' (default): the original mask sparsity regularization in the HAT paper. 59 2. 'cross': the cross version of mask sparsity regularization. 60 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 61 1. 'N01' (default): standard normal distribution $N(0, 1)$. 62 2. 'U-11': uniform distribution $U(-1, 1)$. 63 3. 'U01': uniform distribution $U(0, 1)$. 64 4. 'U-10': uniform distribution $U(-1, 0)$. 65 5. 'last': inherit the task embedding from the last task. 66 - **alpha** (`float` | `None`): the `alpha` in the 'HAT-const-alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'. 
67 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 68 - **kwargs**: Reserved for multiple inheritance. 69 70 """ 71 super().__init__( 72 backbone=backbone, 73 heads=heads, 74 non_algorithmic_hparams=non_algorithmic_hparams, 75 **kwargs, 76 ) 77 78 # save additional algorithmic hyperparameters 79 self.save_hyperparameters( 80 "adjustment_mode", 81 "s_max", 82 "clamp_threshold", 83 "mask_sparsity_reg_factor", 84 "mask_sparsity_reg_mode", 85 "task_embedding_init_mode", 86 "alpha", 87 ) 88 89 self.adjustment_mode: str = adjustment_mode 90 r"""The adjustment mode for gradient clipping.""" 91 self.s_max: float = s_max 92 r"""The hyperparameter s_max.""" 93 self.clamp_threshold: float = clamp_threshold 94 r"""The clamp threshold for task embedding gradient compensation.""" 95 self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor 96 r"""The mask sparsity regularization factor.""" 97 self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode 98 r"""The mask sparsity regularization mode.""" 99 self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg( 100 factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode 101 ) 102 r"""The mask sparsity regularizer.""" 103 self.task_embedding_init_mode: str = task_embedding_init_mode 104 r"""The task embedding initialization mode.""" 105 self.alpha: float | None = alpha 106 r"""The hyperparameter alpha for `hat_const_alpha`.""" 107 # self.epsilon: float | None = None 108 # r"""HAT doesn't use epsilon for `hat_const_alpha`. 
It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`.""" 109 110 self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {} 111 r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """ 112 113 # set manual optimization 114 self.automatic_optimization = False 115 116 HAT.sanity_check(self) 117 118 def sanity_check(self) -> None: 119 r"""Sanity check.""" 120 121 # check the backbone and heads 122 if not isinstance(self.backbone, HATMaskBackbone): 123 raise ValueError("The backbone should be an instance of `HATMaskBackbone`.") 124 if not isinstance(self.heads, HeadsTIL) and not isinstance(self.heads, HeadDIL): 125 raise ValueError( 126 "The heads should be an instance of `HeadsTIL` or `HeadDIL`." 127 ) 128 129 # check marker sparsity regularization mode 130 if self.mask_sparsity_reg_mode not in ["original", "cross"]: 131 raise ValueError( 132 "The mask_sparsity_reg_mode should be one of 'original', 'cross'." 133 ) 134 135 # check task embedding initialization mode 136 if self.task_embedding_init_mode not in [ 137 "N01", 138 "U01", 139 "U-10", 140 "masked", 141 "unmasked", 142 ]: 143 raise ValueError( 144 "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'." 145 ) 146 147 # check adjustment mode `hat_const_alpha` 148 if self.adjustment_mode == "hat_const_alpha" and self.alpha is None: 149 raise ValueError( 150 "Alpha should be given when the adjustment_mode is 'hat_const_alpha'." 
151 ) 152 153 def on_train_start(self) -> None: 154 r"""Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task.""" 155 156 self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode) 157 158 if self.backbone.batch_normalization is not None: 159 self.backbone.initialize_independent_bn() 160 161 # initialize the cumulative mask for the first task at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time. 162 if self.cumulative_mask_for_previous_tasks == {}: 163 for layer_name in self.backbone.weighted_layer_names: 164 layer = self.backbone.get_layer_by_name( 165 layer_name 166 ) # get the layer by its name 167 num_units = layer.weight.shape[0] 168 169 self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros( 170 num_units 171 ).to( 172 self.device 173 ) # the cumulative mask $\mathrm{M}^{<t}$ is initialized as a zeros mask ($t = 1$). See Eq. (2) in Sec. 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 174 175 def clip_grad_by_mask( 176 self, 177 mask: dict[int, Tensor], 178 aggregation_mode: str, 179 ) -> None: 180 r"""Clip the gradients by neuron mask. 181 182 **Args:** 183 - **mask** (`dict[int, Tensor]`): the neuron-wise mask which the gradients outside are clipped. Keys are layer names and values are the corresponding mask tensor for the layer. 184 - **aggregation_mode** (`str`): The aggregation mode mapping two feature-wise measures into a weight-wise matrix; one of: 185 - 'min': takes the minimum of the two connected unit measures. 186 - 'max': takes the maximum of the two connected unit measures. 187 - 'mean': takes the mean of the two connected unit measures. 
188 189 Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. 190 This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code. 191 """ 192 193 for layer_name in self.backbone.weighted_layer_names: 194 195 layer = self.backbone.get_layer_by_name( 196 layer_name 197 ) # get the layer by its name 198 199 weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise( 200 neuron_wise_measure=mask, 201 layer_name=layer_name, 202 aggregation_mode=aggregation_mode, 203 ) 204 205 # apply the adjustment rate to the gradients 206 layer.weight.grad.data *= weight_mask 207 if layer.bias is not None: 208 layer.bias.grad.data *= bias_mask 209 210 def clip_grad_by_adjustment( 211 self, 212 **kwargs, 213 ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]: 214 r"""Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 215 216 Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. 217 This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code. 218 219 Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. 220 See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 221 222 **Returns:** 223 - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 224 - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer name and values (`Tensor`) are the adjustment rate tensors. 
225 - **capacity** (`Tensor`): the calculated network capacity. 226 """ 227 228 # initialize network capacity metric 229 capacity = HATNetworkCapacityMetric().to(self.device) 230 adjustment_rate_weight = {} 231 adjustment_rate_bias = {} 232 233 # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist) 234 for layer_name in self.backbone.weighted_layer_names: 235 236 layer = self.backbone.get_layer_by_name( 237 layer_name 238 ) # get the layer by its name 239 240 # placeholder for the adjustment rate to avoid the error of using it before assignment 241 adjustment_rate_weight_layer = 1 242 adjustment_rate_bias_layer = 1 243 244 weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise( 245 neuron_wise_measure=self.cumulative_mask_for_previous_tasks, 246 layer_name=layer_name, 247 aggregation_mode="min", 248 ) 249 250 if self.adjustment_mode == "hat": 251 adjustment_rate_weight_layer = 1 - weight_mask 252 adjustment_rate_bias_layer = 1 - bias_mask 253 254 elif self.adjustment_mode == "hat_random": 255 adjustment_rate_weight_layer = torch.rand_like( 256 weight_mask 257 ) * weight_mask + (1 - weight_mask) 258 adjustment_rate_bias_layer = torch.rand_like(bias_mask) * bias_mask + ( 259 1 - bias_mask 260 ) 261 262 elif self.adjustment_mode == "hat_const_alpha": 263 adjustment_rate_weight_layer = self.alpha * torch.ones_like( 264 weight_mask 265 ) * weight_mask + (1 - weight_mask) 266 adjustment_rate_bias_layer = self.alpha * torch.ones_like( 267 bias_mask 268 ) * bias_mask + (1 - bias_mask) 269 270 elif self.adjustment_mode == "hat_const_1": 271 adjustment_rate_weight_layer = torch.ones_like( 272 weight_mask 273 ) * weight_mask + (1 - weight_mask) 274 adjustment_rate_bias_layer = torch.ones_like(bias_mask) * bias_mask + ( 275 1 - bias_mask 276 ) 277 278 # apply the adjustment rate to the gradients 279 layer.weight.grad.data *= adjustment_rate_weight_layer 280 if layer.bias is not None: 281 layer.bias.grad.data *= 
adjustment_rate_bias_layer 282 283 # store the adjustment rate for logging 284 adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer 285 if layer.bias is not None: 286 adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer 287 288 # update network capacity metric 289 capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer) 290 291 return adjustment_rate_weight, adjustment_rate_bias, capacity.compute() 292 293 def compensate_task_embedding_gradients( 294 self, 295 batch_idx: int, 296 num_batches: int, 297 ) -> None: 298 r"""Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 299 300 **Args:** 301 - **batch_idx** (`int`): the current training batch index. 302 - **num_batches** (`int`): the total number of training batches. 303 """ 304 305 for te in self.backbone.task_embedding_t.values(): 306 anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * ( 307 batch_idx - 1 308 ) / ( 309 num_batches - 1 310 ) # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 311 312 num = ( 313 torch.cosh( 314 torch.clamp( 315 anneal_scalar * te.weight.data, 316 -self.clamp_threshold, 317 self.clamp_threshold, 318 ) 319 ) 320 + 1 321 ) 322 323 den = torch.cosh(te.weight.data) + 1 324 325 compensation = self.s_max / anneal_scalar * num / den 326 327 te.weight.grad.data *= compensation 328 329 def forward( 330 self, 331 input: torch.Tensor, 332 stage: str, 333 task_id: int | None = None, 334 batch_idx: int | None = None, 335 num_batches: int | None = None, 336 ) -> tuple[Tensor, dict[str, Tensor]]: 337 r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`. 338 339 **Args:** 340 - **input** (`Tensor`): The input tensor from data. 341 - **stage** (`str`): the stage of the forward pass; one of: 342 1. 
'train': training stage. 343 2. 'validation': validation stage. 344 3. 'test': testing stage. 345 - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL. 346 - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`. 347 - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`. 348 349 **Returns:** 350 - **logits** (`Tensor`): the output logits tensor. 351 - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ). 352 - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class. 
353 """ 354 feature, mask, activations = self.backbone( 355 input, 356 stage=stage, 357 s_max=self.s_max if stage == "train" or stage == "validation" else None, 358 batch_idx=batch_idx if stage == "train" else None, 359 num_batches=num_batches if stage == "train" else None, 360 test_task_id=task_id if stage == "test" else None, 361 ) 362 logits = self.heads(feature, task_id) 363 364 return ( 365 logits 366 if self.if_forward_func_return_logits_only 367 else (logits, mask, activations) 368 ) 369 370 def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]: 371 r"""Training step for current task `self.task_id`. 372 373 **Args:** 374 - **batch** (`Any`): a batch of training data. 375 - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 376 377 **Returns:** 378 - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging. 379 """ 380 x, y = batch 381 382 # zero the gradients before forward pass in manual optimization mode 383 opt = self.optimizers() 384 opt.zero_grad() 385 386 # classification loss 387 num_batches = self.trainer.num_training_batches 388 logits, mask, activations = self.forward( 389 x, 390 stage="train", 391 batch_idx=batch_idx, 392 num_batches=num_batches, 393 task_id=self.task_id, 394 ) 395 loss_cls = self.criterion(logits, y) 396 397 # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 398 hat_mask_sparsity_reg, network_sparsity = self.mark_sparsity_reg( 399 mask, self.cumulative_mask_for_previous_tasks 400 ) 401 402 # total loss. See Eq. 
(4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 403 loss = loss_cls + hat_mask_sparsity_reg 404 405 # backward step (manually) 406 self.manual_backward(loss) # calculate the gradients 407 # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper. 408 # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)). 409 410 adjustment_rate_weight, adjustment_rate_bias, capacity = ( 411 self.clip_grad_by_adjustment( 412 network_sparsity=network_sparsity, # passed for compatibility with AdaHAT, which inherits this method 413 ) 414 ) 415 # compensate the gradients of task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 416 self.compensate_task_embedding_gradients( 417 batch_idx=batch_idx, 418 num_batches=num_batches, 419 ) 420 # update parameters with the modified gradients 421 opt.step() 422 423 # predicted labels 424 preds = logits.argmax(dim=1) 425 426 # accuracy of the batch 427 acc = (preds == y).float().mean() 428 429 return { 430 "preds": preds, 431 "loss": loss, # return loss is essential for training step, or backpropagation will fail 432 "loss_cls": loss_cls, 433 "hat_mask_sparsity_reg": hat_mask_sparsity_reg, 434 "acc": acc, 435 "activations": activations, 436 "logits": logits, 437 "mask": mask, # return other metrics for lightning loggers callback to handle at `on_train_batch_end()` 438 "input": x, # return the input batch for Captum to use 439 "target": y, # return the target batch for Captum to use 440 "adjustment_rate_weight": adjustment_rate_weight, # return the adjustment rate for weights and biases for logging 441 "adjustment_rate_bias": adjustment_rate_bias, 442 "capacity": capacity, # return the network capacity for 
logging 443 } 444 445 def on_train_end(self) -> None: 446 r"""The mask and update the cumulative mask after training the task.""" 447 448 # store the mask for the current task 449 mask_t = self.backbone.store_mask() 450 451 # store the batch normalization if necessary 452 if self.backbone.batch_normalization is not None: 453 self.backbone.store_bn() 454 455 # update the cumulative mask. See the first Eq. in Sec 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 456 self.cumulative_mask_for_previous_tasks = { 457 layer_name: torch.max( 458 self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name] 459 ) 460 for layer_name in self.backbone.weighted_layer_names 461 } 462 463 def validation_step(self, batch: Any) -> dict[str, Tensor]: 464 r"""Validation step for current task `self.task_id`. 465 466 **Args:** 467 - **batch** (`Any`): a batch of validation data. 468 469 **Returns:** 470 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 471 """ 472 x, y = batch 473 logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id) 474 loss_cls = self.criterion(logits, y) 475 preds = logits.argmax(dim=1) 476 acc = (preds == y).float().mean() 477 478 return { 479 "preds": preds, 480 "loss_cls": loss_cls, 481 "acc": acc, # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()` 482 } 483 484 def test_step( 485 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 486 ) -> dict[str, Tensor]: 487 r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`. 488 489 **Args:** 490 - **batch** (`Any`): a batch of test data. 491 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 
492 493 **Returns:** 494 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 495 """ 496 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 497 498 x, y = batch 499 logits, _, _ = self.forward( 500 x, 501 stage="test", 502 task_id=test_task_id, 503 ) # use the corresponding head and mask to test (instead of the current task `self.task_id`) 504 loss_cls = self.criterion(logits, y) 505 preds = logits.argmax(dim=1) 506 acc = (preds == y).float().mean() 507 508 return { 509 "preds": preds, 510 "loss_cls": loss_cls, 511 "acc": acc, # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 512 }
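The 'hat' branch of `clip_grad_by_adjustment()` above implements Eq. (2) in Sec. 2.3 of the [HAT paper](http://proceedings.mlr.press/v80/serra18a): a weight's gradient survives only if at least one of the two units it connects is still free of the cumulative mask of previous tasks. For a single linear layer, the `aggregation_mode="min"` step reduces to a broadcasted minimum of two neuron-wise masks. A self-contained sketch, with illustrative names independent of `HATMaskBackbone.get_layer_measure_parameter_wise()`:

```python
import torch

def weight_adjustment_rate(cum_mask_out: torch.Tensor,
                           cum_mask_in: torch.Tensor) -> torch.Tensor:
    # Min-aggregate two neuron-wise cumulative masks into a weight-wise
    # matrix: entry (i, j) is 1 only if BOTH unit i (output side) and
    # unit j (input side) were reserved by previous tasks.
    frozen = torch.minimum(cum_mask_out.unsqueeze(1), cum_mask_in.unsqueeze(0))
    # 'hat' mode: zero the gradient where frozen, keep it where free.
    return 1.0 - frozen

cum_out = torch.tensor([1.0, 0.0])        # output units reserved by tasks < t
cum_in = torch.tensor([1.0, 1.0, 0.0])    # input units reserved by tasks < t
rate = weight_adjustment_rate(cum_out, cum_in)
grad = torch.ones(2, 3)
clipped = grad * rate  # only weights joining two reserved units are zeroed
```

This is why HAT forgets nothing in 'hat' mode: every parameter fully inside the reserved subnetwork receives exactly zero update.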
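`compensate_task_embedding_gradients()` rescales embedding gradients so the annealed sigmoid does not starve them (Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)). The per-element factor it applies can be sketched on its own; the function name is illustrative, and `clamp_threshold` plays the same overflow-guard role as the class attribute above:

```python
import torch

def embedding_grad_compensation(e: torch.Tensor, s: float, s_max: float,
                                clamp_threshold: float = 50.0) -> torch.Tensor:
    # Sec. 2.5 in the HAT paper: multiply the embedding gradient by
    # (s_max / s) * (cosh(clamp(s * e)) + 1) / (cosh(e) + 1), which undoes
    # the vanishing slope of the scaled sigmoid; the clamp keeps
    # cosh from overflowing for large |s * e|.
    num = torch.cosh(torch.clamp(s * e, -clamp_threshold, clamp_threshold)) + 1
    den = torch.cosh(e) + 1
    return (s_max / s) * num / den
```

At `e = 0` the factor is exactly `s_max / s`, i.e., largest early in the epoch when the gate is soft and its gradient smallest.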
for logging 285 adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer 286 if layer.bias is not None: 287 adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer 288 289 # update network capacity metric 290 capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer) 291 292 return adjustment_rate_weight, adjustment_rate_bias, capacity.compute() 293 294 def compensate_task_embedding_gradients( 295 self, 296 batch_idx: int, 297 num_batches: int, 298 ) -> None: 299 r"""Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 300 301 **Args:** 302 - **batch_idx** (`int`): the current training batch index. 303 - **num_batches** (`int`): the total number of training batches. 304 """ 305 306 for te in self.backbone.task_embedding_t.values(): 307 anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * ( 308 batch_idx - 1 309 ) / ( 310 num_batches - 1 311 ) # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 312 313 num = ( 314 torch.cosh( 315 torch.clamp( 316 anneal_scalar * te.weight.data, 317 -self.clamp_threshold, 318 self.clamp_threshold, 319 ) 320 ) 321 + 1 322 ) 323 324 den = torch.cosh(te.weight.data) + 1 325 326 compensation = self.s_max / anneal_scalar * num / den 327 328 te.weight.grad.data *= compensation 329 330 def forward( 331 self, 332 input: torch.Tensor, 333 stage: str, 334 task_id: int | None = None, 335 batch_idx: int | None = None, 336 num_batches: int | None = None, 337 ) -> tuple[Tensor, dict[str, Tensor]]: 338 r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`. 339 340 **Args:** 341 - **input** (`Tensor`): The input tensor from data. 342 - **stage** (`str`): the stage of the forward pass; one of: 343 1. 'train': training stage. 344 2. 'validation': validation stage. 
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): the task ID the data come from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it can be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used.
        - **batch_idx** (`int` | `None`): the current batch index. Applies only to the training stage; defaults to `None` for other stages.
        - **num_batches** (`int` | `None`): the total number of batches. Applies only to the training stage; defaults to `None` for other stages.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **mask** (`dict[str, Tensor]`): the mask for the current task. Keys (`str`) are layer names, values (`Tensor`) are the mask tensors. Each mask tensor has size (number of units, ).
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Keys (`str`) are the weighted layer names, values (`Tensor`) are the hidden feature tensors. This is provided for continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm itself does not need it, it is kept for API consistency with other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
        """
        feature, mask, activations = self.backbone(
            input,
            stage=stage,
            s_max=self.s_max if stage == "train" or stage == "validation" else None,
            batch_idx=batch_idx if stage == "train" else None,
            num_batches=num_batches if stage == "train" else None,
            test_task_id=task_id if stage == "test" else None,
        )
        logits = self.heads(feature, task_id)

        return (
            logits
            if self.if_forward_func_return_logits_only
            else (logits, mask, activations)
        )

    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
        r"""Training step for the current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.
        - **batch_idx** (`int`): the index of the batch. Used for calculating the annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. According to PyTorch Lightning, it must include the key 'loss' (total loss) in the case of automatic optimization. For HAT, it also includes 'mask' and 'capacity' for logging.
        """
        x, y = batch

        # zero the gradients before the forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        num_batches = self.trainer.num_training_batches
        logits, mask, activations = self.forward(
            x,
            stage="train",
            batch_idx=batch_idx,
            num_batches=num_batches,
            task_id=self.task_id,
        )
        loss_cls = self.criterion(logits, y)

        # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        hat_mask_sparsity_reg, network_sparsity = self.mark_sparsity_reg(
            mask, self.cumulative_mask_for_previous_tasks
        )

        # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        loss = loss_cls + hat_mask_sparsity_reg

        # backward step (manual)
        self.manual_backward(loss)  # calculate the gradients
        # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
        # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).

        adjustment_rate_weight, adjustment_rate_bias, capacity = (
            self.clip_grad_by_adjustment(
                network_sparsity=network_sparsity,  # passed for compatibility with AdaHAT, which inherits this method
            )
        )
        # compensate the gradients of the task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.compensate_task_embedding_gradients(
            batch_idx=batch_idx,
            num_batches=num_batches,
        )
        # update parameters with the modified gradients
        opt.step()

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "hat_mask_sparsity_reg": hat_mask_sparsity_reg,
            "acc": acc,
            "activations": activations,
            "logits": logits,
            "mask": mask,  # return other metrics for the lightning loggers callback to handle at `on_train_batch_end()`
            "input": x,  # return the input batch for Captum to use
            "target": y,  # return the target batch for Captum to use
            "adjustment_rate_weight": adjustment_rate_weight,  # return the adjustment rates for weights and biases for logging
            "adjustment_rate_bias": adjustment_rate_bias,
            "capacity": capacity,  # return the network capacity for logging
        }

    def on_train_end(self) -> None:
        r"""Store the mask and update the cumulative mask after training the task."""

        # store the mask for the current task
        mask_t = self.backbone.store_mask()

        # store the batch normalization if necessary
        if self.backbone.batch_normalization is not None:
            self.backbone.store_bn()

        # update the cumulative mask. See the first equation in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.cumulative_mask_for_previous_tasks = {
            layer_name: torch.max(
                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
            )
            for layer_name in self.backbone.weighted_layer_names
        }

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for the current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_validation_batch_end()`
        }

    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for the current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **batch_idx** (`int`): the index of the batch.
        - **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics.
496 """ 497 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 498 499 x, y = batch 500 logits, _, _ = self.forward( 501 x, 502 stage="test", 503 task_id=test_task_id, 504 ) # use the corresponding head and mask to test (instead of the current task `self.task_id`) 505 loss_cls = self.criterion(logits, y) 506 preds = logits.argmax(dim=1) 507 acc = (preds == y).float().mean() 508 509 return { 510 "preds": preds, 511 "loss_cls": loss_cls, 512 "acc": acc, # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 513 }
HAT (Hard Attention to the Task) algorithm.
An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
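HAT's core mechanism gates a learnable task embedding through a scaled sigmoid to produce a per-unit attention mask; as the scaling factor anneals towards `s_max`, the gate saturates towards a binary mask. A minimal sketch of this gating (the function and values here are illustrative, not part of the library API):

```python
import torch

def hard_attention_mask(task_embedding: torch.Tensor, s: float) -> torch.Tensor:
    """Gate a task embedding into a (pseudo-)binary attention mask: sigmoid(s * e)."""
    return torch.sigmoid(s * task_embedding)

e = torch.tensor([-2.0, -0.1, 0.1, 2.0])  # one embedding value per unit
soft = hard_attention_mask(e, s=1.0)      # soft mask early in annealing
hard = hard_attention_mask(e, s=400.0)    # near-binary mask when s is close to s_max
print(soft, hard)
```

With a large `s`, the mask is effectively binary, so multiplying layer outputs by it selects a task-specific sub-network.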
    def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL | HeadDIL,
        adjustment_mode: str,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        task_embedding_init_mode: str = "N01",
        alpha: float | None = None,
        non_algorithmic_hparams: dict[str, Any] = {},
        **kwargs,
    ) -> None:
        r"""Initialize the HAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. HAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping); must be one of:
            1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT completely fixes the part of the network used by previous tasks. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'hat_random': set gradients of parameters linking to masked units to random 0-1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization; must be one of:
            1. 'original' (default): the original mask sparsity regularization in the HAT paper.
            2. 'cross': the cross version of mask sparsity regularization.
        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings; must be one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit the task embedding from the last task.
        - **alpha** (`float` | `None`): the `alpha` in the 'hat_const_alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself, passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method. This is useful for experiment configuration and reproducibility.
        - **kwargs**: reserved for multiple inheritance.
70 71 """ 72 super().__init__( 73 backbone=backbone, 74 heads=heads, 75 non_algorithmic_hparams=non_algorithmic_hparams, 76 **kwargs, 77 ) 78 79 # save additional algorithmic hyperparameters 80 self.save_hyperparameters( 81 "adjustment_mode", 82 "s_max", 83 "clamp_threshold", 84 "mask_sparsity_reg_factor", 85 "mask_sparsity_reg_mode", 86 "task_embedding_init_mode", 87 "alpha", 88 ) 89 90 self.adjustment_mode: str = adjustment_mode 91 r"""The adjustment mode for gradient clipping.""" 92 self.s_max: float = s_max 93 r"""The hyperparameter s_max.""" 94 self.clamp_threshold: float = clamp_threshold 95 r"""The clamp threshold for task embedding gradient compensation.""" 96 self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor 97 r"""The mask sparsity regularization factor.""" 98 self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode 99 r"""The mask sparsity regularization mode.""" 100 self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg( 101 factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode 102 ) 103 r"""The mask sparsity regularizer.""" 104 self.task_embedding_init_mode: str = task_embedding_init_mode 105 r"""The task embedding initialization mode.""" 106 self.alpha: float | None = alpha 107 r"""The hyperparameter alpha for `hat_const_alpha`.""" 108 # self.epsilon: float | None = None 109 # r"""HAT doesn't use epsilon for `hat_const_alpha`. It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`.""" 110 111 self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {} 112 r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """ 113 114 # set manual optimization 115 self.automatic_optimization = False 116 117 HAT.sanity_check(self)
Initialize the HAT algorithm with the network.

**Args:**

- **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
- **heads** (`HeadsTIL` | `HeadDIL`): output heads. HAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
- **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping); must be one of:
    1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT completely fixes the part of the network used by previous tasks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
    2. 'hat_random': set gradients of parameters linking to masked units to random 0-1 values. See "Baselines" in Sec. 4.1 in the AdaHAT paper.
    3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the AdaHAT paper.
    4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the AdaHAT paper.
- **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
- **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
- **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
- **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization; must be one of:
    1. 'original' (default): the original mask sparsity regularization in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularization.
- **task_embedding_init_mode** (`str`): the initialization mode for task embeddings; must be one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
- **alpha** (`float` | `None`): the `alpha` in the 'hat_const_alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
- **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself, passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method. This is useful for experiment configuration and reproducibility.
- **kwargs**: reserved for multiple inheritance.
The mask sparsity regularizer.
The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1, \cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. Each mask tensor has size (number of units, ).
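This attribute is updated in `on_train_end()` as an element-wise maximum, so a unit claimed by any previous task stays protected. A small sketch of that merge (the layer name and values here are illustrative):

```python
import torch

def update_cumulative_mask(
    cumulative: dict[str, torch.Tensor], task_mask: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Element-wise max merge: a unit masked by any task stays masked."""
    return {name: torch.max(cumulative[name], task_mask[name]) for name in cumulative}

cumulative = {"fc1": torch.tensor([1.0, 0.0, 0.0])}  # units claimed by tasks < t
task_mask = {"fc1": torch.tensor([0.0, 1.0, 0.0])}   # units claimed by task t
print(update_cumulative_mask(cumulative, task_mask))  # fc1 becomes [1., 1., 0.]
```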
Sanity check.
Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task.
Clip the gradients by the neuron-wise mask.

**Args:**

- **mask** (`dict[str, Tensor]`): the neuron-wise mask outside which gradients are clipped. Keys are layer names and values are the corresponding mask tensors.
- **aggregation_mode** (`str`): the aggregation mode mapping two neuron-wise measures into a parameter-wise matrix; one of:
    - 'min': takes the minimum of the two connected unit measures.
    - 'max': takes the maximum of the two connected unit measures.
    - 'mean': takes the mean of the two connected unit measures.

Note that because the task embeddings fully cover every weighted layer in the backbone network, no parameters are left out of this mechanism. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.
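The aggregation maps the masks of a layer's output units and its input units onto the weight matrix connecting them: entry $(i, j)$ combines the measures of the two connected units. A sketch of the 'min' aggregation via broadcasting (the function name is illustrative, not the library's `get_layer_measure_parameter_wise`):

```python
import torch

def weight_mask_min(out_mask: torch.Tensor, in_mask: torch.Tensor) -> torch.Tensor:
    """Parameter-wise mask: entry (i, j) = min(out_mask[i], in_mask[j])."""
    # broadcasting (out, 1) against (1, in) yields an (out, in) matrix
    return torch.minimum(out_mask.unsqueeze(1), in_mask.unsqueeze(0))

out_mask = torch.tensor([1.0, 0.0])      # masks of the layer's output units
in_mask = torch.tensor([1.0, 1.0, 0.0])  # masks of the layer's input units
print(weight_mask_min(out_mask, in_mask))
# a weight keeps measure 1 only if both of its endpoint units are masked
```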
Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.

Note that because the task embeddings fully cover every weighted layer in the backbone network, no parameters are left out of this mechanism. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

Network capacity is measured alongside this method; it is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.

**Returns:**

- **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
- **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
- **capacity** (`Tensor`): the calculated network capacity.
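The annealing schedule of Eq. (3) sweeps the scalar from 1/s_max at the first batch to s_max at the last, and the compensation factor above counteracts the vanishing embedding gradients this causes. A self-contained sketch of both computations, assuming a 1-indexed batch number (the function names and the `s_max`/`clamp_threshold` values are illustrative):

```python
import torch

def anneal_scalar(batch_idx: int, num_batches: int, s_max: float) -> float:
    # Eq. (3) in Sec. 2.4 of the HAT paper: linear annealing from 1/s_max to s_max
    # (batch_idx is assumed 1-indexed here)
    return 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)

def embedding_grad_compensation(
    te_weight: torch.Tensor, s: float, s_max: float, clamp_threshold: float
) -> torch.Tensor:
    # Sec. 2.5 of the HAT paper: rescale embedding gradients so they neither
    # vanish (large s) nor explode (small s)
    num = torch.cosh(torch.clamp(s * te_weight, -clamp_threshold, clamp_threshold)) + 1
    den = torch.cosh(te_weight) + 1
    return s_max / s * num / den

s_max, clamp_threshold = 400.0, 50.0
print(anneal_scalar(1, 100, s_max))    # 1/s_max at the first batch
print(anneal_scalar(100, 100, s_max))  # s_max at the last batch
```

At zero embedding weights the compensation reduces to s_max / s, which is exactly 1 when the annealed scalar has reached s_max.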
    def forward(
        self,
        input: torch.Tensor,
        stage: str,
        task_id: int | None = None,
        batch_idx: int | None = None,
        num_batches: int | None = None,
    ) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor]]:
        r"""The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.
        - **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used.
        - **batch_idx** (`int` | `None`): the current batch index. Applies only to the training stage; for other stages, it defaults to `None`.
        - **num_batches** (`int` | `None`): the total number of batches. Applies only to the training stage; for other stages, it defaults to `None`.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is the layer name; value (`Tensor`) is the mask tensor of size (number of units, ).
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name; value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm does not need it, it is provided for API consistency with other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
        """
        feature, mask, activations = self.backbone(
            input,
            stage=stage,
            s_max=self.s_max if stage == "train" or stage == "validation" else None,
            batch_idx=batch_idx if stage == "train" else None,
            num_batches=num_batches if stage == "train" else None,
            test_task_id=task_id if stage == "test" else None,
        )
        logits = self.heads(feature, task_id)

        return (
            logits
            if self.if_forward_func_return_logits_only
            else (logits, mask, activations)
        )
    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
        r"""Training step for the current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.
        - **batch_idx** (`int`): the index of the batch. Used for calculating the annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
        """
        x, y = batch

        # zero the gradients before the forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        num_batches = self.trainer.num_training_batches
        logits, mask, activations = self.forward(
            x,
            stage="train",
            batch_idx=batch_idx,
            num_batches=num_batches,
            task_id=self.task_id,
        )
        loss_cls = self.criterion(logits, y)

        # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        hat_mask_sparsity_reg, network_sparsity = self.mark_sparsity_reg(
            mask, self.cumulative_mask_for_previous_tasks
        )

        # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        loss = loss_cls + hat_mask_sparsity_reg

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
        # Network capacity is computed along with this process (defined as the average adjustment rate over all
        # parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
        adjustment_rate_weight, adjustment_rate_bias, capacity = (
            self.clip_grad_by_adjustment(
                network_sparsity=network_sparsity,  # passed for compatibility with AdaHAT, which inherits this method
            )
        )
        # compensate the gradients of the task embeddings. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.compensate_task_embedding_gradients(
            batch_idx=batch_idx,
            num_batches=num_batches,
        )
        # update parameters with the modified gradients
        opt.step()

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "hat_mask_sparsity_reg": hat_mask_sparsity_reg,
            "acc": acc,
            "activations": activations,
            "logits": logits,
            "mask": mask,  # return other metrics for the lightning loggers callback to handle at `on_train_batch_end()`
            "input": x,  # return the input batch for Captum to use
            "target": y,  # return the target batch for Captum to use
            "adjustment_rate_weight": adjustment_rate_weight,  # return the adjustment rates for weights and biases for logging
            "adjustment_rate_bias": adjustment_rate_bias,
            "capacity": capacity,  # return the network capacity for logging
        }
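The hard clipping step in the training loop above zeroes the gradient of any weight whose input and output units are both fully reserved by previous tasks (Eq. (2) in Sec. 2.3 of the HAT paper). A minimal sketch on a single linear layer, independent of the `HATMaskBackbone` internals (the function name and tensor shapes are illustrative):

```python
import torch

def hat_clip_weight_grad(
    weight_grad: torch.Tensor,       # shape (out_features, in_features)
    cum_mask_out: torch.Tensor,      # cumulative mask of the layer's output units
    cum_mask_in: torch.Tensor,       # cumulative mask of the layer's input units
) -> torch.Tensor:
    # Eq. (2): scale each weight gradient by 1 - min(mask of its output unit,
    # mask of its input unit); a gradient is zeroed only when both units are
    # fully used by previous tasks.
    adjustment = 1 - torch.min(cum_mask_out.unsqueeze(1), cum_mask_in.unsqueeze(0))
    return weight_grad * adjustment

grad = torch.ones(2, 3)
cum_mask_out = torch.tensor([1.0, 0.0])      # first output unit reserved by previous tasks
cum_mask_in = torch.tensor([1.0, 1.0, 0.0])  # first two input units reserved
print(hat_clip_weight_grad(grad, cum_mask_out, cum_mask_in))
```

Only the gradients in the top-left block (reserved output unit meeting reserved input units) are zeroed; every other weight remains free to learn the current task.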
    def on_train_end(self) -> None:
        r"""Store the mask and update the cumulative mask after training the task."""

        # store the mask for the current task
        mask_t = self.backbone.store_mask()

        # store the batch normalization statistics if necessary
        if self.backbone.batch_normalization is not None:
            self.backbone.store_bn()

        # update the cumulative mask. See the first Eq. in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.cumulative_mask_for_previous_tasks = {
            layer_name: torch.max(
                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
            )
            for layer_name in self.backbone.weighted_layer_names
        }
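The cumulative-mask update after training is an element-wise maximum, so a unit counts as reserved once any previous task has used it. A toy illustration with a hypothetical layer name (real keys come from `backbone.weighted_layer_names`):

```python
import torch

# hypothetical cumulative mask after task 1 and the new mask from task 2
cumulative = {"fc1": torch.tensor([1.0, 0.0, 0.0])}
mask_t = {"fc1": torch.tensor([0.0, 1.0, 0.0])}

# first Eq. in Sec. 2.3 of the HAT paper: element-wise max accumulates used units
cumulative = {name: torch.max(cumulative[name], mask_t[name]) for name in cumulative}
print(cumulative["fc1"])  # units used by any task so far
```

Because the masks only ever grow under `torch.max`, the reserved region is monotonically non-decreasing across tasks, which is what guarantees that previous tasks are never overwritten.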
    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for the current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_validation_batch_end()`
        }
    def test_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for the current task `self.task_id`, which tests the seen task indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given; otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y = batch
        logits, _, _ = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_test_batch_end()`
        }