clarena.cl_algorithms.hat
The submodule in `cl_algorithms` for the HAT (Hard Attention to the Task) algorithm.
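Before the source, a minimal stdlib-only sketch of the arithmetic behind HAT's hard attention gate: the mask is the sigmoid of the scaled task embedding, with the scaling factor annealed from 1/s_max to s_max across training batches (Eq. (3) in Sec. 2.4 of the HAT paper). The actual gating lives in `HATMaskBackbone`, not in this submodule; `_sigmoid`, `anneal_scalar`, and `gate` below are hypothetical helper names for illustration only.

```python
import math


def _sigmoid(x: float) -> float:
    # numerically safe logistic function (avoids overflow for large negative x)
    return 1 / (1 + math.exp(-x)) if x >= 0 else math.exp(x) / (1 + math.exp(x))


def anneal_scalar(batch_idx: int, num_batches: int, s_max: float) -> float:
    # linearly anneal the scaling factor s from 1/s_max (first batch) to s_max (last batch);
    # see Eq. (3) in Sec. 2.4 "Hard Attention Training" in the HAT paper
    return 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)


def gate(task_embedding: list[float], s: float) -> list[float]:
    # pseudo-binary attention mask: element-wise sigmoid of the scaled task embedding
    return [_sigmoid(s * e) for e in task_embedding]


# at the last batch s = s_max, so the mask saturates toward {0, 1}
s = anneal_scalar(batch_idx=100, num_batches=100, s_max=400.0)
mask = gate([0.5, -0.5], s)  # close to [1.0, 0.0]
```

With a large `s_max` the gate behaves almost like a step function at training's end, which is what lets the stored masks act as (near-)binary unit selectors at test time.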
1r""" 2The submodule in `cl_algorithms` for [HAT (Hard Attention to the Task) algorithm](http://proceedings.mlr.press/v80/serra18a). 3""" 4 5__all__ = ["HAT"] 6 7import logging 8from typing import Any 9 10import torch 11from torch import Tensor 12from torch.utils.data import DataLoader 13 14from clarena.backbones import HATMaskBackbone 15from clarena.cl_algorithms import CLAlgorithm 16from clarena.cl_algorithms.regularizers import HATMaskSparsityReg 17from clarena.heads import HeadsTIL 18from clarena.utils.metrics import HATNetworkCapacityMetric 19 20# always get logger for built-in logging in each module 21pylogger = logging.getLogger(__name__) 22 23 24class HAT(CLAlgorithm): 25 r"""[HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm. 26 27 An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters. 28 """ 29 30 def __init__( 31 self, 32 backbone: HATMaskBackbone, 33 heads: HeadsTIL, 34 adjustment_mode: str, 35 s_max: float, 36 clamp_threshold: float, 37 mask_sparsity_reg_factor: float, 38 mask_sparsity_reg_mode: str = "original", 39 task_embedding_init_mode: str = "N01", 40 alpha: float | None = None, 41 non_algorithmic_hparams: dict[str, Any] = {}, 42 ) -> None: 43 r"""Initialize the HAT algorithm with the network. 44 45 **Args:** 46 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 47 - **heads** (`HeadsTIL`): output heads. HAT only supports TIL (Task-Incremental Learning). 48 - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of: 49 1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 50 2. 
'hat_random': set gradients of parameters linking to masked units to random 0–1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 51 3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 52 4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 53 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 54 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 55 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity. 56 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 57 1. 'original' (default): the original mask sparsity regularization in the HAT paper. 58 2. 'cross': the cross version of mask sparsity regularization. 59 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 60 1. 'N01' (default): standard normal distribution $N(0, 1)$. 61 2. 'U-11': uniform distribution $U(-1, 1)$. 62 3. 'U01': uniform distribution $U(0, 1)$. 63 4. 'U-10': uniform distribution $U(-1, 0)$. 64 5. 'last': inherit the task embedding from the last task. 65 - **alpha** (`float` | `None`): the `alpha` in the 'HAT-const-alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'. 
66 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 67 68 """ 69 super().__init__( 70 backbone=backbone, 71 heads=heads, 72 non_algorithmic_hparams=non_algorithmic_hparams, 73 ) 74 75 # save additional algorithmic hyperparameters 76 self.save_hyperparameters( 77 "adjustment_mode", 78 "s_max", 79 "clamp_threshold", 80 "mask_sparsity_reg_factor", 81 "mask_sparsity_reg_mode", 82 "task_embedding_init_mode", 83 "alpha", 84 ) 85 86 self.adjustment_mode: str = adjustment_mode 87 r"""The adjustment mode for gradient clipping.""" 88 self.s_max: float = s_max 89 r"""The hyperparameter s_max.""" 90 self.clamp_threshold: float = clamp_threshold 91 r"""The clamp threshold for task embedding gradient compensation.""" 92 self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor 93 r"""The mask sparsity regularization factor.""" 94 self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode 95 r"""The mask sparsity regularization mode.""" 96 self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg( 97 factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode 98 ) 99 r"""The mask sparsity regularizer.""" 100 self.task_embedding_init_mode: str = task_embedding_init_mode 101 r"""The task embedding initialization mode.""" 102 self.alpha: float | None = alpha 103 r"""The hyperparameter alpha for `hat_const_alpha`.""" 104 # self.epsilon: float | None = None 105 # r"""HAT doesn't use epsilon for `hat_const_alpha`. 
It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`.""" 106 107 self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {} 108 r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1,\cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ). """ 109 110 # set manual optimization 111 self.automatic_optimization = False 112 113 HAT.sanity_check(self) 114 115 def sanity_check(self) -> None: 116 r"""Sanity check.""" 117 118 # check the backbone and heads 119 if not isinstance(self.backbone, HATMaskBackbone): 120 raise ValueError("The backbone should be an instance of `HATMaskBackbone`.") 121 if not isinstance(self.heads, HeadsTIL): 122 raise ValueError("The heads should be an instance of `HeadsTIL`.") 123 124 # check marker sparsity regularization mode 125 if self.mask_sparsity_reg_mode not in ["original", "cross"]: 126 raise ValueError( 127 "The mask_sparsity_reg_mode should be one of 'original', 'cross'." 128 ) 129 130 # check task embedding initialization mode 131 if self.task_embedding_init_mode not in [ 132 "N01", 133 "U01", 134 "U-10", 135 "masked", 136 "unmasked", 137 ]: 138 raise ValueError( 139 "The task_embedding_init_mode should be one of 'N01', 'U01', 'U-10', 'masked', 'unmasked'." 140 ) 141 142 # check adjustment mode `hat_const_alpha` 143 if self.adjustment_mode == "hat_const_alpha" and self.alpha is None: 144 raise ValueError( 145 "Alpha should be given when the adjustment_mode is 'hat_const_alpha'." 
146 ) 147 148 def on_train_start(self) -> None: 149 r"""Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task.""" 150 151 self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode) 152 153 self.backbone.initialize_independent_bn() 154 155 # initialize the cumulative mask for the first task at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time. 156 if self.task_id == 1: 157 for layer_name in self.backbone.weighted_layer_names: 158 layer = self.backbone.get_layer_by_name( 159 layer_name 160 ) # get the layer by its name 161 num_units = layer.weight.shape[0] 162 163 self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros( 164 num_units 165 ).to( 166 self.device 167 ) # the cumulative mask $\mathrm{M}^{<t}$ is initialized as a zeros mask ($t = 1$). See Eq. (2) in Sec. 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 168 169 # self.neuron_first_task[layer_name] = [None] * num_units 170 171 def clip_grad_by_adjustment( 172 self, 173 **kwargs, 174 ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]: 175 r"""Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 176 177 Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. 178 This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code. 179 180 Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. 181 See Sec. 
4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 182 183 **Returns:** 184 - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 185 - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer name and values (`Tensor`) are the adjustment rate tensors. 186 - **capacity** (`Tensor`): the calculated network capacity. 187 """ 188 189 # initialize network capacity metric 190 capacity = HATNetworkCapacityMetric().to(self.device) 191 adjustment_rate_weight = {} 192 adjustment_rate_bias = {} 193 194 # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist) 195 for layer_name in self.backbone.weighted_layer_names: 196 197 layer = self.backbone.get_layer_by_name( 198 layer_name 199 ) # get the layer by its name 200 201 # placeholder for the adjustment rate to avoid the error of using it before assignment 202 adjustment_rate_weight_layer = 1 203 adjustment_rate_bias_layer = 1 204 205 weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise( 206 neuron_wise_measure=self.cumulative_mask_for_previous_tasks, 207 layer_name=layer_name, 208 aggregation_mode="min", 209 ) 210 211 if self.adjustment_mode == "hat": 212 adjustment_rate_weight_layer = 1 - weight_mask 213 adjustment_rate_bias_layer = 1 - bias_mask 214 215 elif self.adjustment_mode == "hat_random": 216 adjustment_rate_weight_layer = torch.rand_like( 217 weight_mask 218 ) * weight_mask + (1 - weight_mask) 219 adjustment_rate_bias_layer = torch.rand_like(bias_mask) * bias_mask + ( 220 1 - bias_mask 221 ) 222 223 elif self.adjustment_mode == "hat_const_alpha": 224 adjustment_rate_weight_layer = self.alpha * torch.ones_like( 225 weight_mask 226 ) * weight_mask + (1 - weight_mask) 227 adjustment_rate_bias_layer = self.alpha * torch.ones_like( 228 bias_mask 229 ) * 
bias_mask + (1 - bias_mask) 230 231 elif self.adjustment_mode == "hat_const_1": 232 adjustment_rate_weight_layer = torch.ones_like( 233 weight_mask 234 ) * weight_mask + (1 - weight_mask) 235 adjustment_rate_bias_layer = torch.ones_like(bias_mask) * bias_mask + ( 236 1 - bias_mask 237 ) 238 239 # apply the adjustment rate to the gradients 240 layer.weight.grad.data *= adjustment_rate_weight_layer 241 if layer.bias is not None: 242 layer.bias.grad.data *= adjustment_rate_bias_layer 243 244 # store the adjustment rate for logging 245 adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer 246 if layer.bias is not None: 247 adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer 248 249 # update network capacity metric 250 capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer) 251 252 return adjustment_rate_weight, adjustment_rate_bias, capacity.compute() 253 254 def compensate_task_embedding_gradients( 255 self, 256 batch_idx: int, 257 num_batches: int, 258 ) -> None: 259 r"""Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 260 261 **Args:** 262 - **batch_idx** (`int`): the current training batch index. 263 - **num_batches** (`int`): the total number of training batches. 264 """ 265 266 for te in self.backbone.task_embedding_t.values(): 267 anneal_scalar = 1 / self.s_max + (self.s_max - 1 / self.s_max) * ( 268 batch_idx - 1 269 ) / ( 270 num_batches - 1 271 ) # see Eq. (3) in Sec. 
2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 272 273 num = ( 274 torch.cosh( 275 torch.clamp( 276 anneal_scalar * te.weight.data, 277 -self.clamp_threshold, 278 self.clamp_threshold, 279 ) 280 ) 281 + 1 282 ) 283 284 den = torch.cosh(te.weight.data) + 1 285 286 compensation = self.s_max / anneal_scalar * num / den 287 288 te.weight.grad.data *= compensation 289 290 def forward( 291 self, 292 input: torch.Tensor, 293 stage: str, 294 task_id: int | None = None, 295 batch_idx: int | None = None, 296 num_batches: int | None = None, 297 ) -> tuple[Tensor, dict[str, Tensor]]: 298 r"""The forward pass for data from task `task_id`. Note that it is nothing to do with `forward()` method in `nn.Module`. 299 300 **Args:** 301 - **input** (`Tensor`): The input tensor from data. 302 - **stage** (`str`): the stage of the forward pass; one of: 303 1. 'train': training stage. 304 2. 'validation': validation stage. 305 3. 'test': testing stage. 306 - **task_id** (`int`| `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If stage is 'test', it could be from any seen task. In TIL, the task IDs of test data are provided thus this argument can be used. HAT algorithm works only for TIL. 307 - **batch_idx** (`int` | `None`): the current batch index. Applies only to training stage. For other stages, it is default `None`. 308 - **num_batches** (`int` | `None`): the total number of batches. Applies only to training stage. For other stages, it is default `None`. 309 310 **Returns:** 311 - **logits** (`Tensor`): the output logits tensor. 312 - **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ). 313 - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. 
Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used for the continual learning algorithms that need to use the hidden features for various purposes. Although HAT algorithm does not need this, it is still provided for API consistence for other HAT-based algorithms inherited this `forward()` method of `HAT` class. 314 """ 315 feature, mask, activations = self.backbone( 316 input, 317 stage=stage, 318 s_max=self.s_max if stage == "train" or stage == "validation" else None, 319 batch_idx=batch_idx if stage == "train" else None, 320 num_batches=num_batches if stage == "train" else None, 321 test_task_id=task_id if stage == "test" else None, 322 ) 323 logits = self.heads(feature, task_id) 324 325 return ( 326 logits 327 if self.if_forward_func_return_logits_only 328 else (logits, mask, activations) 329 ) 330 331 def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]: 332 r"""Training step for current task `self.task_id`. 333 334 **Args:** 335 - **batch** (`Any`): a batch of training data. 336 - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 337 338 **Returns:** 339 - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging. 
340 """ 341 x, y = batch 342 343 # zero the gradients before forward pass in manual optimization mode 344 opt = self.optimizers() 345 opt.zero_grad() 346 347 # classification loss 348 num_batches = self.trainer.num_training_batches 349 logits, mask, activations = self.forward( 350 x, 351 stage="train", 352 batch_idx=batch_idx, 353 num_batches=num_batches, 354 task_id=self.task_id, 355 ) 356 loss_cls = self.criterion(logits, y) 357 358 # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 359 loss_reg, network_sparsity = self.mark_sparsity_reg( 360 mask, self.cumulative_mask_for_previous_tasks 361 ) 362 363 # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 364 loss = loss_cls + loss_reg 365 366 # backward step (manually) 367 self.manual_backward(loss) # calculate the gradients 368 # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper. 369 # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)). 370 371 adjustment_rate_weight, adjustment_rate_bias, capacity = ( 372 self.clip_grad_by_adjustment( 373 network_sparsity=network_sparsity, # passed for compatibility with AdaHAT, which inherits this method 374 ) 375 ) 376 # compensate the gradients of task embedding. See Sec. 
2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 377 self.compensate_task_embedding_gradients( 378 batch_idx=batch_idx, 379 num_batches=num_batches, 380 ) 381 # update parameters with the modified gradients 382 opt.step() 383 384 # accuracy of the batch 385 acc = (logits.argmax(dim=1) == y).float().mean() 386 387 return { 388 "loss": loss, # return loss is essential for training step, or backpropagation will fail 389 "loss_cls": loss_cls, 390 "loss_reg": loss_reg, 391 "acc": acc, 392 "activations": activations, 393 "logits": logits, 394 "mask": mask, # return other metrics for lightning loggers callback to handle at `on_train_batch_end()` 395 "input": x, # return the input batch for Captum to use 396 "target": y, # return the target batch for Captum to use 397 "adjustment_rate_weight": adjustment_rate_weight, # return the adjustment rate for weights and biases for logging 398 "adjustment_rate_bias": adjustment_rate_bias, 399 "capacity": capacity, # return the network capacity for logging 400 } 401 402 def on_train_end(self) -> None: 403 r"""The mask and update the cumulative mask after training the task.""" 404 405 # store the mask for the current task 406 mask_t = self.backbone.store_mask() 407 408 # store the batch normalization if necessary 409 self.backbone.store_bn() 410 411 # update the cumulative mask. See the first Eq. in Sec 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 412 self.cumulative_mask_for_previous_tasks = { 413 layer_name: torch.max( 414 self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name] 415 ) 416 for layer_name in self.backbone.weighted_layer_names 417 } 418 419 def validation_step(self, batch: Any) -> dict[str, Tensor]: 420 r"""Validation step for current task `self.task_id`. 421 422 **Args:** 423 - **batch** (`Any`): a batch of validation data. 
424 425 **Returns:** 426 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 427 """ 428 x, y = batch 429 logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id) 430 loss_cls = self.criterion(logits, y) 431 acc = (logits.argmax(dim=1) == y).float().mean() 432 433 return { 434 "loss_cls": loss_cls, 435 "acc": acc, # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()` 436 } 437 438 def test_step( 439 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 440 ) -> dict[str, Tensor]: 441 r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`. 442 443 **Args:** 444 - **batch** (`Any`): a batch of test data. 445 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 446 447 **Returns:** 448 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 449 """ 450 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 451 452 x, y = batch 453 logits, _, _ = self.forward( 454 x, 455 stage="test", 456 task_id=test_task_id, 457 ) # use the corresponding head and mask to test (instead of the current task `self.task_id`) 458 loss_cls = self.criterion(logits, y) 459 acc = (logits.argmax(dim=1) == y).float().mean() 460 461 return { 462 "loss_cls": loss_cls, 463 "acc": acc, # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 464 }
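The scalar arithmetic behind the 'hat' mode of `clip_grad_by_adjustment` and the cumulative-mask update in `on_train_end` can be sketched without torch (hypothetical helper names; the real methods operate on whole mask tensors per layer):

```python
def hat_adjustment_rate(mask_unit_out: float, mask_unit_in: float) -> float:
    # 'hat' mode, Eq. (2) in Sec. 2.3 of the HAT paper: a weight's gradient is scaled by
    # 1 minus the min of the cumulative masks of the two units it connects
    # (cf. aggregation_mode="min"), so it is zeroed only when both units are
    # reserved by previous tasks
    return 1 - min(mask_unit_out, mask_unit_in)


def update_cumulative_mask(cumulative: list[float], mask_t: list[float]) -> list[float]:
    # element-wise max, as in `on_train_end`: once a unit is claimed by any task,
    # it stays claimed in the cumulative mask
    return [max(c, m) for c, m in zip(cumulative, mask_t)]


# a weight between two fully masked units is frozen; any free endpoint keeps it trainable
frozen = hat_adjustment_rate(1.0, 1.0)  # 0.0 -> gradient zeroed
free = hat_adjustment_rate(1.0, 0.0)    # 1.0 -> gradient untouched
cumulative = update_cumulative_mask([0.0, 1.0], [1.0, 0.0])  # [1.0, 1.0]
```

This is why HAT never forgets in 'hat' mode: every parameter whose both endpoints are in the cumulative mask receives a zero gradient for all later tasks, while the other adjustment modes relax exactly this scaling factor.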
334 335 **Args:** 336 - **batch** (`Any`): a batch of training data. 337 - **batch_idx** (`int`): the index of the batch. Used for calculating annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 338 339 **Returns:** 340 - **outputs** (`dict[str, Tensor]`): a dictionary containing loss and other metrics from this training step. Keys (`str`) are metric names, and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging. 341 """ 342 x, y = batch 343 344 # zero the gradients before forward pass in manual optimization mode 345 opt = self.optimizers() 346 opt.zero_grad() 347 348 # classification loss 349 num_batches = self.trainer.num_training_batches 350 logits, mask, activations = self.forward( 351 x, 352 stage="train", 353 batch_idx=batch_idx, 354 num_batches=num_batches, 355 task_id=self.task_id, 356 ) 357 loss_cls = self.criterion(logits, y) 358 359 # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 360 loss_reg, network_sparsity = self.mark_sparsity_reg( 361 mask, self.cumulative_mask_for_previous_tasks 362 ) 363 364 # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 365 loss = loss_cls + loss_reg 366 367 # backward step (manually) 368 self.manual_backward(loss) # calculate the gradients 369 # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper. 370 # Network capacity is computed along with this process (defined as the average adjustment rate over all parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)). 
371 372 adjustment_rate_weight, adjustment_rate_bias, capacity = ( 373 self.clip_grad_by_adjustment( 374 network_sparsity=network_sparsity, # passed for compatibility with AdaHAT, which inherits this method 375 ) 376 ) 377 # compensate the gradients of task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 378 self.compensate_task_embedding_gradients( 379 batch_idx=batch_idx, 380 num_batches=num_batches, 381 ) 382 # update parameters with the modified gradients 383 opt.step() 384 385 # accuracy of the batch 386 acc = (logits.argmax(dim=1) == y).float().mean() 387 388 return { 389 "loss": loss, # return loss is essential for training step, or backpropagation will fail 390 "loss_cls": loss_cls, 391 "loss_reg": loss_reg, 392 "acc": acc, 393 "activations": activations, 394 "logits": logits, 395 "mask": mask, # return other metrics for lightning loggers callback to handle at `on_train_batch_end()` 396 "input": x, # return the input batch for Captum to use 397 "target": y, # return the target batch for Captum to use 398 "adjustment_rate_weight": adjustment_rate_weight, # return the adjustment rate for weights and biases for logging 399 "adjustment_rate_bias": adjustment_rate_bias, 400 "capacity": capacity, # return the network capacity for logging 401 } 402 403 def on_train_end(self) -> None: 404 r"""The mask and update the cumulative mask after training the task.""" 405 406 # store the mask for the current task 407 mask_t = self.backbone.store_mask() 408 409 # store the batch normalization if necessary 410 self.backbone.store_bn() 411 412 # update the cumulative mask. See the first Eq. 
in Sec 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a) 413 self.cumulative_mask_for_previous_tasks = { 414 layer_name: torch.max( 415 self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name] 416 ) 417 for layer_name in self.backbone.weighted_layer_names 418 } 419 420 def validation_step(self, batch: Any) -> dict[str, Tensor]: 421 r"""Validation step for current task `self.task_id`. 422 423 **Args:** 424 - **batch** (`Any`): a batch of validation data. 425 426 **Returns:** 427 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 428 """ 429 x, y = batch 430 logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id) 431 loss_cls = self.criterion(logits, y) 432 acc = (logits.argmax(dim=1) == y).float().mean() 433 434 return { 435 "loss_cls": loss_cls, 436 "acc": acc, # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()` 437 } 438 439 def test_step( 440 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 441 ) -> dict[str, Tensor]: 442 r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`. 443 444 **Args:** 445 - **batch** (`Any`): a batch of test data. 446 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 447 448 **Returns:** 449 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 
450 """ 451 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 452 453 x, y = batch 454 logits, _, _ = self.forward( 455 x, 456 stage="test", 457 task_id=test_task_id, 458 ) # use the corresponding head and mask to test (instead of the current task `self.task_id`) 459 loss_cls = self.criterion(logits, y) 460 acc = (logits.argmax(dim=1) == y).float().mean() 461 462 return { 463 "loss_cls": loss_cls, 464 "acc": acc, # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 465 }
HAT (Hard Attention to the Task) algorithm.
An architecture-based continual learning approach that uses learnable hard attention masks to select task-specific parameters.
    def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL,
        adjustment_mode: str,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        task_embedding_init_mode: str = "N01",
        alpha: float | None = None,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""Initialize the HAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL`): output heads. HAT only supports TIL (Task-Incremental Learning).
        - **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping); must be one of:
            1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT completely fixes the part of the network used by previous tasks. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
            2. 'hat_random': set gradients of parameters linking to masked units to random 0-1 values. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
            4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization; must be one of:
            1. 'original' (default): the original mask sparsity regularization in the HAT paper.
            2. 'cross': the cross version of mask sparsity regularization.
        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings; must be one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$.
            2. 'U-11': uniform distribution $U(-1, 1)$.
            3. 'U01': uniform distribution $U(0, 1)$.
            4. 'U-10': uniform distribution $U(-1, 0)$.
            5. 'last': inherit the task embedding from the last task.
        - **alpha** (`float` | `None`): the `alpha` in the 'hat_const_alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (such as optimizer and learning rate scheduler configurations) that are not related to the algorithm itself, passed to this `LightningModule` object from the config. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "adjustment_mode",
            "s_max",
            "clamp_threshold",
            "mask_sparsity_reg_factor",
            "mask_sparsity_reg_mode",
            "task_embedding_init_mode",
            "alpha",
        )

        self.adjustment_mode: str = adjustment_mode
        r"""The adjustment mode for gradient clipping."""
        self.s_max: float = s_max
        r"""The hyperparameter s_max."""
        self.clamp_threshold: float = clamp_threshold
        r"""The clamp threshold for task embedding gradient compensation."""
        self.mask_sparsity_reg_factor: float = mask_sparsity_reg_factor
        r"""The mask sparsity regularization factor."""
        self.mask_sparsity_reg_mode: str = mask_sparsity_reg_mode
        r"""The mask sparsity regularization mode."""
        self.mark_sparsity_reg: HATMaskSparsityReg = HATMaskSparsityReg(
            factor=mask_sparsity_reg_factor, mode=mask_sparsity_reg_mode
        )
        r"""The mask sparsity regularizer."""
        self.task_embedding_init_mode: str = task_embedding_init_mode
        r"""The task embedding initialization mode."""
        self.alpha: float | None = alpha
        r"""The hyperparameter alpha for `hat_const_alpha`."""
        # self.epsilon: float | None = None
        # r"""HAT doesn't use epsilon for `hat_const_alpha`. It is kept for consistency with `epsilon` in `clip_grad_by_adjustment()` in `HATMaskBackbone`."""

        self.cumulative_mask_for_previous_tasks: dict[str, Tensor] = {}
        r"""The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1, \cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, )."""

        # set manual optimization
        self.automatic_optimization = False

        HAT.sanity_check(self)
Initialize the HAT algorithm with the network.

**Args:**

- **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
- **heads** (`HeadsTIL`): output heads. HAT only supports TIL (Task-Incremental Learning).
- **adjustment_mode** (`str`): the strategy of adjustment (i.e., the mode of gradient clipping), must be one of:
    1. 'hat': set gradients of parameters linking to masked units to zero. This is how HAT fixes the part of the network for previous tasks completely. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
    2. 'hat_random': set gradients of parameters linking to masked units to random 0-1 values. See "Baselines" in Sec. 4.1 in the AdaHAT paper.
    3. 'hat_const_alpha': set gradients of parameters linking to masked units to a constant value `alpha`. See "Baselines" in Sec. 4.1 in the AdaHAT paper.
    4. 'hat_const_1': set gradients of parameters linking to masked units to a constant value of 1 (i.e., no gradient constraint). See "Baselines" in Sec. 4.1 in the AdaHAT paper.
- **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
- **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
- **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
- **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
    1. 'original' (default): the original mask sparsity regularization in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularization.
- **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
- **alpha** (`float` | `None`): the `alpha` in the 'hat_const_alpha' mode. Applies only when `adjustment_mode` is 'hat_const_alpha'.
- **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (such as optimizer and learning rate scheduler configurations) that are not related to the algorithm itself, passed to this `LightningModule` object from the config. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
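The role of `s_max` can be sketched in a few lines: the hard attention is a sigmoid of the task embedding scaled by an annealed factor that grows linearly from $1/s_{\max}$ on the first batch to $s_{\max}$ on the last (Eq. (3) in the HAT paper). This is a hedged sketch with 1-based `batch_idx`, not clarena's API:

```python
import torch

def annealed_scalar(s_max: float, batch_idx: int, num_batches: int) -> float:
    # linear annealing from 1/s_max (first batch) to s_max (last batch)
    return 1 / s_max + (s_max - 1 / s_max) * (batch_idx - 1) / (num_batches - 1)

def hard_attention(task_embedding: torch.Tensor, s: float) -> torch.Tensor:
    # gated attention a = sigmoid(s * e); a large s pushes values toward binary {0, 1}
    return torch.sigmoid(s * task_embedding)

e = torch.tensor([2.0, -2.0, 0.1])
soft = hard_attention(e, annealed_scalar(400.0, 1, 100))    # near-uniform early on
hard = hard_attention(e, annealed_scalar(400.0, 100, 100))  # near-binary at the end
```

With a small scalar the gate stays differentiable and trainable; at `s_max` it is effectively a binary mask, which is what the clipping and cumulative-mask machinery consumes.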
The mask sparsity regularizer.
The cumulative binary attention mask $\mathrm{M}^{<t}$ of previous tasks $1, \cdots, t-1$, gated from the task embedding ($t$ is `self.task_id`). It is a dict where keys are layer names and values are the binary mask tensors for the layers. The mask tensor has size (number of units, ).
    @property
    def automatic_optimization(self) -> bool:
        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
        return self._automatic_optimization

If set to `False` you are responsible for calling `.backward()`, `.step()`, `.zero_grad()`.
    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check the backbone and heads
        if not isinstance(self.backbone, HATMaskBackbone):
            raise ValueError("The backbone should be an instance of `HATMaskBackbone`.")
        if not isinstance(self.heads, HeadsTIL):
            raise ValueError("The heads should be an instance of `HeadsTIL`.")

        # check the mask sparsity regularization mode
        if self.mask_sparsity_reg_mode not in ["original", "cross"]:
            raise ValueError(
                "The mask_sparsity_reg_mode should be one of 'original', 'cross'."
            )

        # check the task embedding initialization mode (must match the documented modes)
        if self.task_embedding_init_mode not in [
            "N01",
            "U-11",
            "U01",
            "U-10",
            "last",
        ]:
            raise ValueError(
                "The task_embedding_init_mode should be one of 'N01', 'U-11', 'U01', 'U-10', 'last'."
            )

        # check adjustment mode `hat_const_alpha`
        if self.adjustment_mode == "hat_const_alpha" and self.alpha is None:
            raise ValueError(
                "Alpha should be given when the adjustment_mode is 'hat_const_alpha'."
            )
Sanity check.
    def on_train_start(self) -> None:
        r"""Initialize the task embedding before training the next task, and initialize the cumulative mask at the beginning of the first task."""

        self.backbone.initialize_task_embedding(mode=self.task_embedding_init_mode)

        self.backbone.initialize_independent_bn()

        # initialize the cumulative mask at the beginning of the first task. This should not be called in `__init__()` because `self.device` is not available at that time.
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(layer_name)  # get the layer by its name
                num_units = layer.weight.shape[0]

                # the cumulative mask $\mathrm{M}^{<t}$ is initialized as a zero mask ($t = 1$). See Eq. (2) in Sec. 3 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9), or Eq. (5) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
                self.cumulative_mask_for_previous_tasks[layer_name] = torch.zeros(
                    num_units
                ).to(self.device)
Initialize the task embedding before training the next task and initialize the cumulative mask at the beginning of the first task.
    def clip_grad_by_adjustment(
        self,
        **kwargs,
    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
        r"""Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        Note that because the task embeddings fully cover every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

        Network capacity is measured alongside this method. It is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        **Returns:**
        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **capacity** (`Tensor`): the calculated network capacity.
        """
Clip the gradients by the adjustment rate. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.

Note that because the task embeddings fully cover every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

Network capacity is measured alongside this method. It is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.

**Returns:**

- **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
- **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
- **capacity** (`Tensor`): the calculated network capacity.
Compensate the gradients of task embeddings during training. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.

**Args:**

- **batch_idx** (`int`): the current training batch index.
- **num_batches** (`int`): the total number of training batches.
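The compensation factor is a ratio of gradient envelopes, $(\cosh(\operatorname{clamp}(s e)) + 1) / (\cosh(e) + 1)$ scaled by $s_{\max}/s$; at $e = 0$ it equals $s_{\max}/s$, undoing the annealed scalar's damping of the sigmoid's gradient. A standalone sketch of the formula (names are illustrative):

```python
import torch

def embedding_grad_compensation(
    e: torch.Tensor, s: float, s_max: float, clamp_threshold: float = 50.0
) -> torch.Tensor:
    # clamp s * e so cosh does not overflow for large embeddings
    num = torch.cosh(torch.clamp(s * e, -clamp_threshold, clamp_threshold)) + 1
    den = torch.cosh(e) + 1
    return s_max / s * num / den

# at e = 0: num = den = 2, so the factor is exactly s_max / s
factor = embedding_grad_compensation(torch.zeros(3), s=2.0, s_max=400.0)
```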
The forward pass for data from task `task_id`. Note that it has nothing to do with the `forward()` method in `nn.Module`.

**Args:**

- **input** (`Tensor`): the input tensor from data.
- **stage** (`str`): the stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.
- **task_id** (`int` | `None`): the task ID where the data are from. If the stage is 'train' or 'validation', it should be the current task `self.task_id`. If the stage is 'test', it could be any seen task. In TIL, the task IDs of test data are provided, so this argument can be used. The HAT algorithm works only for TIL.
- **batch_idx** (`int` | `None`): the current batch index. Applies only to the training stage; defaults to `None` for other stages.
- **num_batches** (`int` | `None`): the total number of batches. Applies only to the training stage; defaults to `None` for other stages.

**Returns:**

- **logits** (`Tensor`): the output logits tensor.
- **mask** (`dict[str, Tensor]`): the mask for the current task. Key (`str`) is the layer name, value (`Tensor`) is the mask tensor. The mask tensor has size (number of units, ).
- **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is provided for continual learning algorithms that need the hidden features for various purposes. Although the HAT algorithm itself does not need it, it is kept for API consistency with other HAT-based algorithms that inherit this `forward()` method of the `HAT` class.
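How `self.heads(feature, task_id)` resolves in TIL can be pictured as one independent linear head per task, selected by the task ID supplied with the data. A minimal sketch, assuming 1-based task IDs as elsewhere in this class; `TILHeads` is a hypothetical stand-in for `HeadsTIL`, not its actual implementation:

```python
import torch
from torch import nn

class TILHeads(nn.Module):
    """One output head per task; the task ID selects which head to apply."""

    def __init__(self, in_features: int, num_classes: int, num_tasks: int) -> None:
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Linear(in_features, num_classes) for _ in range(num_tasks)
        )

    def forward(self, feature: torch.Tensor, task_id: int) -> torch.Tensor:
        return self.heads[task_id - 1](feature)  # task IDs assumed 1-based

heads = TILHeads(in_features=8, num_classes=3, num_tasks=2)
logits = heads(torch.randn(4, 8), task_id=2)
```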
    def training_step(self, batch: Any, batch_idx: int) -> dict[str, Tensor]:
        r"""Training step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of training data.
        - **batch_idx** (`int`): the index of the batch. Used for calculating the annealed scalar in HAT. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this training step. Keys (`str`) are metric names and values (`Tensor`) are the metrics. Must include the key 'loss' (total loss) in the case of automatic optimization, according to PyTorch Lightning. For HAT, it includes 'mask' and 'capacity' for logging.
        """
        x, y = batch

        # zero the gradients before the forward pass in manual optimization mode
        opt = self.optimizers()
        opt.zero_grad()

        # classification loss
        num_batches = self.trainer.num_training_batches
        logits, mask, activations = self.forward(
            x,
            stage="train",
            batch_idx=batch_idx,
            num_batches=num_batches,
            task_id=self.task_id,
        )
        loss_cls = self.criterion(logits, y)

        # regularization loss. See Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        loss_reg, network_sparsity = self.mark_sparsity_reg(
            mask, self.cumulative_mask_for_previous_tasks
        )

        # total loss. See Eq. (4) in Sec. 2.6 "Promoting Low Capacity Usage" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        loss = loss_cls + loss_reg

        # backward step (manually)
        self.manual_backward(loss)  # calculate the gradients
        # HAT hard-clips gradients using the cumulative masks. See Eq. (2) in Sec. 2.3 "Network Training" in the HAT paper.
        # Network capacity is computed along with this process (defined as the average adjustment rate over all
        # parameters; see Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9)).
        adjustment_rate_weight, adjustment_rate_bias, capacity = (
            self.clip_grad_by_adjustment(
                network_sparsity=network_sparsity,  # passed for compatibility with AdaHAT, which inherits this method
            )
        )
        # compensate the gradients of the task embedding. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.compensate_task_embedding_gradients(
            batch_idx=batch_idx,
            num_batches=num_batches,
        )
        # update parameters with the modified gradients
        opt.step()

        # accuracy of the batch
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "loss_reg": loss_reg,
            "acc": acc,
            "activations": activations,
            "logits": logits,
            "mask": mask,  # return other metrics for the lightning loggers callback to handle at `on_train_batch_end()`
            "input": x,  # return the input batch for Captum to use
            "target": y,  # return the target batch for Captum to use
            "adjustment_rate_weight": adjustment_rate_weight,  # return the adjustment rates for weights and biases for logging
            "adjustment_rate_bias": adjustment_rate_bias,
            "capacity": capacity,  # return the network capacity for logging
        }
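The `batch_idx` and `num_batches` arguments passed into `forward()` above drive HAT's annealed scalar $s$ (Sec. 2.4 "Hard Attention Training" in the HAT paper): $s$ grows from $1/s_{max}$ on the first batch of an epoch to $s_{max}$ on the last, so the sigmoid-gated attention starts soft and saturates towards a near-binary mask. A minimal sketch of that schedule, using hypothetical standalone helpers rather than the actual class internals:

```python
import torch


def annealed_scalar(s_max: float, batch_idx: int, num_batches: int) -> float:
    # Anneal s linearly from 1/s_max (first batch) to s_max (last batch),
    # per Sec. 2.4 of the HAT paper.
    return 1 / s_max + (s_max - 1 / s_max) * batch_idx / (num_batches - 1)


def hard_attention_mask(task_embedding: torch.Tensor, s: float) -> torch.Tensor:
    # A large s saturates the sigmoid, pushing the mask towards binary {0, 1}.
    return torch.sigmoid(s * task_embedding)


e = torch.tensor([-2.0, 0.0, 2.0])  # a toy per-layer task embedding
soft_mask = hard_attention_mask(e, annealed_scalar(400.0, 0, 100))   # near 0.5 everywhere
hard_mask = hard_attention_mask(e, annealed_scalar(400.0, 99, 100))  # near binary
```

The soft mask early in the epoch keeps gradients flowing to the task embedding; the saturated mask at the end of training is what gets binarized and stored per task.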
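The `clip_grad_by_adjustment()` call in the training step implements Eq. (2) of the HAT paper: the gradient of the weight between unit $i$ of layer $l$ and unit $j$ of layer $l-1$ is scaled by $1 - \min(a^{<t}_{l,i}, a^{<t}_{l-1,j})$, so weights attached on both ends to units fully reserved by previous tasks receive zero gradient. A standalone sketch of that rule for a single linear layer (a hypothetical helper, not the class method itself):

```python
import torch


def hat_clip_weight_grad(
    grad: torch.Tensor,          # weight gradient, shape (out_features, in_features)
    cum_mask_out: torch.Tensor,  # cumulative attention of this layer's units, shape (out_features,)
    cum_mask_in: torch.Tensor,   # cumulative attention of the previous layer's units, shape (in_features,)
) -> torch.Tensor:
    # Eq. (2): scale each gradient entry by 1 - min(a_out_i, a_in_j);
    # the weight between two fully-masked units (both 1) gets zero gradient.
    factor = 1 - torch.min(
        cum_mask_out.unsqueeze(1).expand_as(grad),
        cum_mask_in.unsqueeze(0).expand_as(grad),
    )
    return grad * factor


grad = torch.ones(2, 2)
clipped = hat_clip_weight_grad(grad, torch.tensor([1.0, 0.0]), torch.tensor([1.0, 0.0]))
# only the weight between the two masked units is frozen: clipped[0, 0] == 0
```

Because the scaling uses the element-wise minimum, a weight stays trainable as long as either of its endpoint units is still free for new tasks.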
    def on_train_end(self) -> None:
        r"""Store the mask and update the cumulative mask after training the task."""

        # store the mask for the current task
        mask_t = self.backbone.store_mask()

        # store the batch normalization if necessary
        self.backbone.store_bn()

        # update the cumulative mask. See the first Eq. in Sec. 2.3 "Network Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a)
        self.cumulative_mask_for_previous_tasks = {
            layer_name: torch.max(
                self.cumulative_mask_for_previous_tasks[layer_name], mask_t[layer_name]
            )
            for layer_name in self.backbone.weighted_layer_names
        }
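The element-wise `torch.max` above merges the newly trained task's binary mask into the running union of all previous tasks' masks (the first equation in Sec. 2.3 of the HAT paper). A toy illustration with a hypothetical layer name:

```python
import torch

cumulative = {"fc1": torch.tensor([1.0, 0.0, 1.0, 0.0])}  # units used by tasks 1..t-1
mask_t = {"fc1": torch.tensor([0.0, 1.0, 1.0, 0.0])}      # units used by task t

# element-wise max of binary masks is their union
cumulative = {
    name: torch.max(cumulative[name], mask_t[name]) for name in cumulative
}
# cumulative["fc1"] is now tensor([1., 1., 1., 0.])
```

This union is exactly what `clip_grad_by_adjustment()` consults in later tasks to decide which gradients to zero out.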
    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step for current task `self.task_id`.

        **Args:**
        - **batch** (`Any`): a batch of validation data.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this validation step. Keys (`str`) are metric names and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _, _ = self.forward(x, stage="validation", task_id=self.task_id)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_validation_batch_end()`
        }
    def test_step(
        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step for current task `self.task_id`, which tests all seen tasks indexed by `dataloader_idx`.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **batch_idx** (`int`): the index of the batch.
        - **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given; otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and other metrics from this test step. Keys (`str`) are metric names and values (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y = batch
        logits, _, _ = self.forward(
            x,
            stage="test",
            task_id=test_task_id,
        )  # use the corresponding head and mask to test (instead of the current task `self.task_id`)
        loss_cls = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        return {
            "loss_cls": loss_cls,
            "acc": acc,  # return metrics for the lightning loggers callback to handle at `on_test_batch_end()`
        }