clarena.cl_algorithms.fgadahat
The submodule in cl_algorithms for FG-AdaHAT algorithm.
1r""" 2The submodule in `cl_algorithms` for FG-AdaHAT algorithm. 3""" 4 5__all__ = ["FGAdaHAT"] 6 7import logging 8import math 9from typing import Any 10 11import torch 12from captum.attr import ( 13 InternalInfluence, 14 LayerConductance, 15 LayerDeepLift, 16 LayerDeepLiftShap, 17 LayerFeatureAblation, 18 LayerGradCam, 19 LayerGradientShap, 20 LayerGradientXActivation, 21 LayerIntegratedGradients, 22 LayerLRP, 23) 24from torch import Tensor 25 26from clarena.backbones import HATMaskBackbone 27from clarena.cl_algorithms import AdaHAT 28from clarena.heads import HeadDIL, HeadsTIL 29from clarena.utils.metrics import HATNetworkCapacityMetric 30from clarena.utils.transforms import min_max_normalize 31 32# always get logger for built-in logging in each module 33pylogger = logging.getLogger(__name__) 34 35 36class FGAdaHAT(AdaHAT): 37 r"""FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm. 38 39 An architecture-based continual learning approach that improves [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT. 40 41 We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components. 
42 """ 43 44 def __init__( 45 self, 46 backbone: HATMaskBackbone, 47 heads: HeadsTIL | HeadDIL, 48 adjustment_intensity: float, 49 importance_type: str, 50 importance_summing_strategy: str, 51 importance_scheduler_type: str, 52 neuron_to_weight_importance_aggregation_mode: str, 53 s_max: float, 54 clamp_threshold: float, 55 mask_sparsity_reg_factor: float, 56 mask_sparsity_reg_mode: str = "original", 57 base_importance: float = 0.01, 58 base_mask_sparsity_reg: float = 0.1, 59 base_linear: float = 10, 60 filter_by_cumulative_mask: bool = False, 61 filter_unmasked_importance: bool = True, 62 step_multiply_training_mask: bool = True, 63 task_embedding_init_mode: str = "N01", 64 importance_summing_strategy_linear_step: float | None = None, 65 importance_summing_strategy_exponential_rate: float | None = None, 66 importance_summing_strategy_log_base: float | None = None, 67 non_algorithmic_hparams: dict[str, Any] = {}, 68 **kwargs, 69 ) -> None: 70 r"""Initialize the FG-AdaHAT algorithm with the network. 71 72 **Args:** 73 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 74 - **heads** (`HeadsTIL` | `HeadDIL`): output heads. FG-AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning). 75 - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper). 76 - **importance_type** (`str`): the type of neuron-wise importance, must be one of: 77 1. 'input_weight_abs_sum': sum of absolute input weights; 78 2. 'output_weight_abs_sum': sum of absolute output weights; 79 3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper); 80 4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper); 81 5. 'activation_abs': absolute activation; 82 6. 
'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper); 83 7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper); 84 8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation; 85 9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights; 86 10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights; 87 11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper); 88 12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation; 89 13. 'conductance_abs': absolute layer conductance; 90 14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper); 91 15. 'gradcam_abs': absolute Grad-CAM; 92 16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper); 93 17. 'deepliftshap_abs': absolute DeepLIFT-SHAP; 94 18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper); 95 19. 'integrated_gradients_abs': absolute Integrated Gradients; 96 20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper); 97 21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP); 98 22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 99 23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 100 - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of: 101 1. 
'add_latest': add the latest neuron-wise importance to the summative importance; 102 2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance; 103 3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance; 104 4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID; 105 5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID; 106 6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID; 107 7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID; 108 8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID; 109 9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID; 110 - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of: 111 1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization betwwen the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`); 112 2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$; 113 3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$. 114 - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of: 115 1. 
'min': take the minimum of neuron-wise importance for each weight; 116 2. 'max': take the maximum of neuron-wise importance for each weight; 117 3. 'mean': take the mean of neuron-wise importance for each weight. 118 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 119 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 120 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity. 121 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 122 1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 123 2. 'cross': the cross version of mask sparsity regularization. 124 - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01. 125 - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1. 126 - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10. 127 - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False. 128 - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set to 0) at the end of task training. Default: False. 129 - **step_multiply_training_mask** (`bool`): whether to multiply the training mask to the importance at each training step. Default: True. 130 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 131 1. 
'N01' (default): standard normal distribution $N(0, 1)$. 132 2. 'U-11': uniform distribution $U(-1, 1)$. 133 3. 'U01': uniform distribution $U(0, 1)$. 134 4. 'U-10': uniform distribution $U(-1, 0)$. 135 5. 'last': inherit the task embedding from the last task. 136 - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0. 137 - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1. 138 - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1. 139 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 140 - **kwargs**: Reserved for multiple inheritance. 
141 142 """ 143 super().__init__( 144 backbone=backbone, 145 heads=heads, 146 adjustment_mode=None, # use the own adjustment mechanism of FG-AdaHAT 147 adjustment_intensity=adjustment_intensity, 148 s_max=s_max, 149 clamp_threshold=clamp_threshold, 150 mask_sparsity_reg_factor=mask_sparsity_reg_factor, 151 mask_sparsity_reg_mode=mask_sparsity_reg_mode, 152 task_embedding_init_mode=task_embedding_init_mode, 153 epsilon=base_mask_sparsity_reg, # the epsilon is now the base mask sparsity regularization factor 154 non_algorithmic_hparams=non_algorithmic_hparams, 155 **kwargs, 156 ) 157 158 # save additional algorithmic hyperparameters 159 self.save_hyperparameters( 160 "adjustment_intensity", 161 "importance_type", 162 "importance_summing_strategy", 163 "importance_scheduler_type", 164 "neuron_to_weight_importance_aggregation_mode", 165 "s_max", 166 "clamp_threshold", 167 "mask_sparsity_reg_factor", 168 "mask_sparsity_reg_mode", 169 "base_importance", 170 "base_mask_sparsity_reg", 171 "base_linear", 172 "filter_by_cumulative_mask", 173 "filter_unmasked_importance", 174 "step_multiply_training_mask", 175 ) 176 177 self.importance_type: str | None = importance_type 178 r"""The type of the neuron-wise importance added to AdaHAT importance.""" 179 180 self.importance_scheduler_type: str = importance_scheduler_type 181 r"""The type of the importance scheduler.""" 182 self.neuron_to_weight_importance_aggregation_mode: str = ( 183 neuron_to_weight_importance_aggregation_mode 184 ) 185 r"""The mode of aggregation from neuron-wise to weight-wise importance. 
""" 186 self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask 187 r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate.""" 188 self.filter_unmasked_importance: bool = filter_unmasked_importance 189 r"""The flag to filter unmasked importance values (set them to 0) at the end of task training.""" 190 self.step_multiply_training_mask: bool = step_multiply_training_mask 191 r"""The flag to multiply the training mask to the importance at each training step.""" 192 193 # importance summing strategy 194 self.importance_summing_strategy: str = importance_summing_strategy 195 r"""The strategy to sum the neuron-wise importance for previous tasks.""" 196 if importance_summing_strategy_linear_step is not None: 197 self.importance_summing_strategy_linear_step: float = ( 198 importance_summing_strategy_linear_step 199 ) 200 r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease').""" 201 if importance_summing_strategy_exponential_rate is not None: 202 self.importance_summing_strategy_exponential_rate: float = ( 203 importance_summing_strategy_exponential_rate 204 ) 205 r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """ 206 if importance_summing_strategy_log_base is not None: 207 self.importance_summing_strategy_log_base: float = ( 208 importance_summing_strategy_log_base 209 ) 210 r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """ 211 212 # base values 213 self.base_importance: float = base_importance 214 r"""The base value added to the importance to avoid zero. """ 215 self.base_mask_sparsity_reg: float = base_mask_sparsity_reg 216 r"""The base value added to the mask sparsity regularization to avoid zero. 
""" 217 self.base_linear: float = base_linear 218 r"""The base value added to the linear layer to avoid zero. """ 219 220 self.importances: dict[int, dict[str, Tensor]] = {} 221 r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """ 222 self.summative_importance_for_previous_tasks: dict[str, Tensor] = {} 223 r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """ 224 225 self.num_steps_t: int 226 r"""The number of training steps for the current task `self.task_id`.""" 227 # set manual optimization 228 self.automatic_optimization = False 229 230 FGAdaHAT.sanity_check(self) 231 232 def sanity_check(self) -> None: 233 r"""Sanity check.""" 234 235 # check importance type 236 if self.importance_type not in [ 237 "input_weight_abs_sum", 238 "output_weight_abs_sum", 239 "input_weight_gradient_abs_sum", 240 "output_weight_gradient_abs_sum", 241 "activation_abs", 242 "input_weight_abs_sum_x_activation_abs", 243 "output_weight_abs_sum_x_activation_abs", 244 "gradient_x_activation_abs", 245 "input_weight_gradient_square_sum", 246 "output_weight_gradient_square_sum", 247 "input_weight_gradient_square_sum_x_activation_abs", 248 "output_weight_gradient_square_sum_x_activation_abs", 249 "conductance_abs", 250 "internal_influence_abs", 251 "gradcam_abs", 252 "deeplift_abs", 253 "deepliftshap_abs", 254 "gradientshap_abs", 255 "integrated_gradients_abs", 256 "feature_ablation_abs", 257 
"lrp_abs", 258 "cbp_adaptation", 259 "cbp_adaptive_contribution", 260 ]: 261 raise ValueError( 262 f"importance_type must be one of the predefined types, but got {self.importance_type}" 263 ) 264 265 # check importance summing strategy 266 if self.importance_summing_strategy not in [ 267 "add_latest", 268 "add_all", 269 "add_average", 270 "linear_decrease", 271 "quadratic_decrease", 272 "cubic_decrease", 273 "exponential_decrease", 274 "log_decrease", 275 "factorial_decrease", 276 ]: 277 raise ValueError( 278 f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}" 279 ) 280 281 # check importance scheduler type 282 if self.importance_scheduler_type not in [ 283 "linear_sparsity_reg", 284 "sparsity_reg", 285 "summative_mask_sparsity_reg", 286 ]: 287 raise ValueError( 288 f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}" 289 ) 290 291 # check neuron to weight importance aggregation mode 292 if self.neuron_to_weight_importance_aggregation_mode not in [ 293 "min", 294 "max", 295 "mean", 296 ]: 297 raise ValueError( 298 f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}" 299 ) 300 301 # check base values 302 if self.base_importance < 0: 303 raise ValueError( 304 f"base_importance must be >= 0, but got {self.base_importance}" 305 ) 306 if self.base_mask_sparsity_reg <= 0: 307 raise ValueError( 308 f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}" 309 ) 310 if self.base_linear <= 0: 311 raise ValueError(f"base_linear must be > 0, but got {self.base_linear}") 312 313 def on_train_start(self) -> None: 314 r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.""" 315 super().on_train_start() 316 317 self.importances[self.task_id] = ( 318 {} 319 ) # initialize the 
    def on_train_start(self) -> None:
        r"""Initialize the neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization."""
        super().on_train_start()

        # initialize the importance store for the current task. This cannot be done in `__init__()` because `self.device` is not available there
        self.importances[self.task_id] = {}

        for layer_name in self.backbone.weighted_layer_names:
            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name
            num_units = layer.weight.shape[0]

            # initialize the accumulated importance at the beginning of each task
            self.importances[self.task_id][layer_name] = torch.zeros(num_units).to(
                self.device
            )

            # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros when $t = 1$
            if self.task_id == 1:
                self.summative_importance_for_previous_tasks[layer_name] = torch.zeros(
                    num_units
                ).to(self.device)

        # reset the number of steps counter for the current task
        self.num_steps_t = 0
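The bookkeeping above can be sketched without the framework: one zero importance vector per layer for the new task, plus a zero summative importance before the first task. Layer names, unit counts, and plain lists standing in for tensors are all hypothetical here:

```python
layer_units = {"conv1": 4, "fc1": 3}  # assumed layer names and unit counts

importances: dict[int, dict[str, list[float]]] = {}
summative_importance: dict[str, list[float]] = {}

def start_task(task_id: int) -> None:
    # one zero importance vector per layer for the new task
    importances[task_id] = {name: [0.0] * n for name, n in layer_units.items()}
    # the summative importance starts from zeros only before the first task
    if task_id == 1:
        for name, n in layer_units.items():
            summative_importance[name] = [0.0] * n

start_task(1)
```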
    def clip_grad_by_adjustment(
        self,
        network_sparsity: dict[str, Tensor],
    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
        r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper.

        Note that the task embeddings fully cover every layer in the backbone network, so no parameter is left out of this mechanism. This applies not only to the parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

        **Args:**
        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., the mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to construct the importance scheduler, which in turn determines the adjustment rate for the gradients.

        **Returns:**
        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
        - **capacity** (`Tensor`): the calculated network capacity.
        """

        # initialize network capacity metric
        capacity = HATNetworkCapacityMetric().to(self.device)
        adjustment_rate_weight = {}
        adjustment_rate_bias = {}

        # calculate the adjustment rate for the gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper
        for layer_name in self.backbone.weighted_layer_names:

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # placeholders for the adjustment rates, to avoid use before assignment
            adjustment_rate_weight_layer = 1
            adjustment_rate_bias_layer = 1

            # aggregate the neuron-wise importance into weight-wise importance. The neuron-wise importance has already been min-max scaled to $[0, 1]$ in `on_train_batch_end()`, offset by the base value, and filtered by the mask
            weight_importance, bias_importance = (
                self.backbone.get_layer_measure_parameter_wise(
                    neuron_wise_measure=self.summative_importance_for_previous_tasks,
                    layer_name=layer_name,
                    aggregation_mode=self.neuron_to_weight_importance_aggregation_mode,
                )
            )

            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
                layer_name=layer_name,
                aggregation_mode="min",
            )

            # filter the importance by the cumulative mask
            if self.filter_by_cumulative_mask:
                weight_importance = weight_importance * weight_mask
                bias_importance = bias_importance * bias_mask

            network_sparsity_layer = network_sparsity[layer_name]

            # calculate the importance scheduler (the factor multiplied to the importance). See Eq. (3) in the paper
            factor = network_sparsity_layer + self.base_mask_sparsity_reg
            factor_weight = factor
            factor_bias = factor
            if self.importance_scheduler_type == "linear_sparsity_reg":
                factor_weight = factor * (self.task_id + self.base_linear)
                factor_bias = factor_weight
            elif self.importance_scheduler_type == "sparsity_reg":
                pass
            elif self.importance_scheduler_type == "summative_mask_sparsity_reg":
                # parameter-wise min of the summative masks of the connected units, matching $c^t_{l,ij}$ in Eq. (3)
                weight_summative_mask, bias_summative_mask = (
                    self.backbone.get_layer_measure_parameter_wise(
                        neuron_wise_measure=self.summative_mask_for_previous_tasks,
                        layer_name=layer_name,
                        aggregation_mode="min",
                    )
                )
                factor_weight = factor * (weight_summative_mask + self.base_linear)
                factor_bias = factor * (bias_summative_mask + self.base_linear)

            # calculate the adjustment rate
            adjustment_rate_weight_layer = torch.div(
                self.adjustment_intensity,
                (factor_weight * weight_importance + self.adjustment_intensity),
            )

            adjustment_rate_bias_layer = torch.div(
                self.adjustment_intensity,
                (factor_bias * bias_importance + self.adjustment_intensity),
            )

            # apply the adjustment rate to the gradients
            layer.weight.grad.data *= adjustment_rate_weight_layer
            if layer.bias is not None:
                layer.bias.grad.data *= adjustment_rate_bias_layer

            # store the adjustment rates for logging
            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
            if layer.bias is not None:
                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer

            # update the network capacity metric
            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)

        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
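The effect of the adjustment rate $\rho = \alpha \, / \, (c^t \cdot I + \alpha)$ from Eq. (2) is easiest to see on scalars. A minimal sketch with invented values for the intensity $\alpha$, the scheduler factor $c^t$, and the parameter importance $I$ (the helper name is illustrative; the class computes this per-tensor with `torch.div`):

```python
def adjustment_rate(alpha: float, factor: float, importance: float) -> float:
    # rho = alpha / (c^t * I + alpha): always in (0, 1], smaller for more important parameters
    return alpha / (factor * importance + alpha)

# a parameter with zero accumulated importance keeps its full gradient
rate_free = adjustment_rate(alpha=0.1, factor=1.3, importance=0.0)

# a highly important parameter has its gradient strongly shrunk
rate_important = adjustment_rate(alpha=0.1, factor=1.3, importance=1.0)
```

Raising `adjustment_intensity` ($\alpha$) pushes all rates toward 1 (weaker protection of old tasks); lowering it pushes rates of important parameters toward 0 (stronger protection).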
    def on_train_batch_end(
        self, outputs: dict[str, Any], batch: Any, batch_idx: int
    ) -> None:
        r"""Calculate the step-wise importance, then update the accumulated importance and the number of steps counter after each training step.

        **Args:**
        - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`).
        - **batch** (`Any`): training data batch.
        - **batch_idx** (`int`): index of the current batch (for mask figure file name).
        """

        # get potentially useful information from the training batch
        activations = outputs["activations"]
        input = outputs["input"]
        target = outputs["target"]
        mask = outputs["mask"]
        num_batches = self.trainer.num_training_batches

        for layer_name in self.backbone.weighted_layer_names:
            # layer-wise operation

            activation = activations[layer_name]

            # calculate the neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x}, y)$ (before Eqs. (5) and (6)) in the paper
            if self.importance_type == "input_weight_abs_sum":
                importance_step = self.get_importance_step_layer_weight_abs_sum(
                    layer_name=layer_name,
                    if_output_weight=False,
                    reciprocal=False,
                )
            elif self.importance_type == "output_weight_abs_sum":
                importance_step = self.get_importance_step_layer_weight_abs_sum(
                    layer_name=layer_name,
                    if_output_weight=True,
                    reciprocal=False,
                )
            elif self.importance_type == "input_weight_gradient_abs_sum":
                importance_step = (
                    self.get_importance_step_layer_weight_gradient_abs_sum(
                        layer_name=layer_name, if_output_weight=False
                    )
                )
            elif self.importance_type == "output_weight_gradient_abs_sum":
                importance_step = (
                    self.get_importance_step_layer_weight_gradient_abs_sum(
                        layer_name=layer_name, if_output_weight=True
                    )
                )
            elif self.importance_type == "activation_abs":
                importance_step = self.get_importance_step_layer_activation_abs(
                    activation=activation
                )
            elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
                importance_step = (
                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
                        layer_name=layer_name,
                        activation=activation,
                        if_output_weight=False,
                    )
                )
            elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
                importance_step = (
                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
                        layer_name=layer_name,
                        activation=activation,
                        if_output_weight=True,
                    )
                )
            elif self.importance_type == "gradient_x_activation_abs":
                importance_step = (
                    self.get_importance_step_layer_gradient_x_activation_abs(
                        layer_name=layer_name,
                        input=input,
                        target=target,
                        batch_idx=batch_idx,
                        num_batches=num_batches,
                    )
                )
            elif self.importance_type == "input_weight_gradient_square_sum":
                importance_step = (
                    self.get_importance_step_layer_weight_gradient_square_sum(
                        layer_name=layer_name,
                        activation=activation,
                        if_output_weight=False,
                    )
                )
            elif self.importance_type == "output_weight_gradient_square_sum":
                importance_step = (
                    self.get_importance_step_layer_weight_gradient_square_sum(
                        layer_name=layer_name,
                        activation=activation,
                        if_output_weight=True,
                    )
                )
            elif (
                self.importance_type
                == "input_weight_gradient_square_sum_x_activation_abs"
            ):
                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=False,
                )
            elif (
                self.importance_type
                == "output_weight_gradient_square_sum_x_activation_abs"
            ):
                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=True,
                )
            elif self.importance_type == "conductance_abs":
                importance_step = self.get_importance_step_layer_conductance_abs(
                    layer_name=layer_name,
                    input=input,
                    baselines=None,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            elif self.importance_type == "internal_influence_abs":
                importance_step = self.get_importance_step_layer_internal_influence_abs(
                    layer_name=layer_name,
                    input=input,
                    baselines=None,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            elif self.importance_type == "gradcam_abs":
                importance_step = self.get_importance_step_layer_gradcam_abs(
                    layer_name=layer_name,
                    input=input,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            elif self.importance_type == "deeplift_abs":
                importance_step = self.get_importance_step_layer_deeplift_abs(
                    layer_name=layer_name,
                    input=input,
                    baselines=None,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            elif self.importance_type == "deepliftshap_abs":
                importance_step = self.get_importance_step_layer_deepliftshap_abs(
                    layer_name=layer_name,
                    input=input,
                    baselines=None,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            elif self.importance_type == "gradientshap_abs":
                importance_step = self.get_importance_step_layer_gradientshap_abs(
                    layer_name=layer_name,
                    input=input,
                    baselines=None,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            elif self.importance_type == "integrated_gradients_abs":
                importance_step = (
                    self.get_importance_step_layer_integrated_gradients_abs(
                        layer_name=layer_name,
                        input=input,
                        baselines=None,
                        target=target,
                        batch_idx=batch_idx,
                        num_batches=num_batches,
                    )
                )
            elif self.importance_type == "feature_ablation_abs":
                importance_step = self.get_importance_step_layer_feature_ablation_abs(
                    layer_name=layer_name,
                    input=input,
                    layer_baselines=None,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            elif self.importance_type == "lrp_abs":
                importance_step = self.get_importance_step_layer_lrp_abs(
                    layer_name=layer_name,
                    input=input,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            elif self.importance_type == "cbp_adaptation":
                importance_step = self.get_importance_step_layer_weight_abs_sum(
                    layer_name=layer_name,
                    if_output_weight=False,
                    reciprocal=True,
                )
            elif self.importance_type == "cbp_adaptive_contribution":
                importance_step = (
                    self.get_importance_step_layer_cbp_adaptive_contribution(
                        layer_name=layer_name,
                        activation=activation,
                    )
                )

            importance_step = min_max_normalize(
                importance_step
            )  # min-max scale the importance to $[0, 1]$. See Eq. (5) in the paper

            # multiply the importance by the training mask. See Eq. (6) in the paper
            if self.step_multiply_training_mask:
                importance_step = importance_step * mask[layer_name]

            # update the accumulated importance
            self.importances[self.task_id][layer_name] = (
                self.importances[self.task_id][layer_name] + importance_step
            )

        # update the number of steps counter
        self.num_steps_t += 1
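Per step, Eqs. (4)-(6) amount to min-max scaling each step's raw importance, masking it, and averaging over steps. A stdlib-only sketch with invented raw importance values and a plain-list stand-in for `min_max_normalize` (the library's own transform may handle edge cases differently):

```python
def min_max(values: list[float]) -> list[float]:
    # scale to [0, 1]; a constant vector is mapped to all zeros (assumed convention)
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0] * len(values)
    return [(v - lo) / (hi - lo) for v in values]

raw_steps = [[2.0, 6.0, 4.0], [1.0, 3.0, 2.0]]  # two training steps, three units
mask = [1.0, 1.0, 0.0]                          # unit 2 is outside the training mask

accumulated = [0.0, 0.0, 0.0]
for step in raw_steps:
    scaled = min_max(step)                            # Eq. (5): scale to [0, 1]
    masked = [s * m for s, m in zip(scaled, mask)]    # Eq. (6): keep masked units only
    accumulated = [a + v for a, v in zip(accumulated, masked)]

# Eq. (4): average over the number of steps
importance = [a / len(raw_steps) for a in accumulated]
```

Note how masking before accumulation means unmasked units contribute nothing to the per-task importance, mirroring `step_multiply_training_mask=True`.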
            if self.importance_summing_strategy == "add_latest":
                self.summative_importance_for_previous_tasks[
                    layer_name
                ] += self.importances[self.task_id][layer_name]

            elif self.importance_summing_strategy == "add_all":
                for t in range(1, self.task_id + 1):
                    self.summative_importance_for_previous_tasks[
                        layer_name
                    ] += self.importances[t][layer_name]

            elif self.importance_summing_strategy == "add_average":
                for t in range(1, self.task_id + 1):
                    self.summative_importance_for_previous_tasks[layer_name] += (
                        self.importances[t][layer_name] / self.task_id
                    )
            else:
                # weighted strategies: restart the sum from zero, then add each task's importance weighted by a factor $w_t$ that decreases with the task ID
                self.summative_importance_for_previous_tasks[
                    layer_name
                ] = torch.zeros_like(
                    self.summative_importance_for_previous_tasks[layer_name]
                ).to(
                    self.device
                )

                for t in range(1, self.task_id + 1):
                    if self.importance_summing_strategy == "linear_decrease":
                        s = self.importance_summing_strategy_linear_step
                        w_t = s * (self.task_id - t) + 1
                    elif self.importance_summing_strategy == "quadratic_decrease":
                        w_t = (self.task_id - t + 1) ** 2
                    elif self.importance_summing_strategy == "cubic_decrease":
                        w_t = (self.task_id - t + 1) ** 3
                    elif self.importance_summing_strategy == "exponential_decrease":
                        r = self.importance_summing_strategy_exponential_rate
                        w_t = r ** (self.task_id - t + 1)
                    elif self.importance_summing_strategy == "log_decrease":
                        a = self.importance_summing_strategy_log_base
                        # shift by 1 so the latest task ($t$ = `self.task_id`) gets weight 1 instead of a log(0) domain error
                        w_t = math.log(self.task_id - t + 1, a) + 1
                    elif self.importance_summing_strategy == "factorial_decrease":
                        w_t = math.factorial(self.task_id - t + 1)
                    else:
                        raise ValueError

                    self.summative_importance_for_previous_tasks[layer_name] += (
                        self.importances[t][layer_name] * w_t
                    )

    def get_importance_step_layer_weight_abs_sum(
        self,
        layer_name: str,
        if_output_weight: bool,
        reciprocal: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x}, y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the layer's input or output weights.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get the neuron-wise importance for.
        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
        - **reciprocal** (`bool`): whether to take the reciprocal.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            weight_abs = torch.abs(layer.weight.data)
            weight_abs_sum = torch.sum(
                weight_abs,
                dim=[
                    i for i in range(weight_abs.dim()) if i != 0
                ],  # sum over the input dimension
            )
        else:
            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
            weight_abs_sum = torch.sum(
                weight_abs,
                dim=[
                    i for i in range(weight_abs.dim()) if i != 1
                ],  # sum over the output dimension
            )

        if reciprocal:
            importance_step_layer = torch.reciprocal(weight_abs_sum)
        else:
            importance_step_layer = weight_abs_sum
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
This method uses the sum of absolute values of gradients of the layer input or output weights. 789 790 **Args:** 791 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 792 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 793 794 **Returns:** 795 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 796 """ 797 layer = self.backbone.get_layer_by_name(layer_name) 798 799 if not if_output_weight: 800 gradient_abs = torch.abs(layer.weight.grad.data) 801 gradient_abs_sum = torch.sum( 802 gradient_abs, 803 dim=[ 804 i for i in range(gradient_abs.dim()) if i != 0 805 ], # sum over the input dimension 806 ) 807 else: 808 gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data) 809 gradient_abs_sum = torch.sum( 810 gradient_abs, 811 dim=[ 812 i for i in range(gradient_abs.dim()) if i != 1 813 ], # sum over the output dimension 814 ) 815 816 importance_step_layer = gradient_abs_sum 817 importance_step_layer = importance_step_layer.detach() 818 819 return importance_step_layer 820 821 def get_importance_step_layer_activation_abs( 822 self: str, 823 activation: Tensor, 824 ) -> Tensor: 825 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum. 826 827 **Args:** 828 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 829 830 **Returns:** 831 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
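
        **Example:**

        A standalone sketch of the underlying computation (independent of this class; the tensor values are made up for illustration):

        ```python
        import torch

        # batch of activations with shape (batch size, number of units)
        activation = torch.tensor([[1.0, -2.0], [3.0, -4.0]])

        # batch-mean of absolute activations, one value per unit
        importance = torch.abs(activation).mean(dim=0)
        # importance is tensor([2., 3.])
        ```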
832 """ 833 activation_abs_batch_mean = torch.mean( 834 torch.abs(activation), 835 dim=[ 836 i for i in range(activation.dim()) if i != 1 837 ], # average the features over batch samples 838 ) 839 importance_step_layer = activation_abs_batch_mean 840 importance_step_layer = importance_step_layer.detach() 841 842 return importance_step_layer 843 844 def get_importance_step_layer_weight_abs_sum_x_activation_abs( 845 self: str, 846 layer_name: str, 847 activation: Tensor, 848 if_output_weight: bool, 849 ) -> Tensor: 850 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7). 851 852 **Args:** 853 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 854 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 855 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 856 857 **Returns:** 858 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
859 """ 860 layer = self.backbone.get_layer_by_name(layer_name) 861 862 if not if_output_weight: 863 weight_abs = torch.abs(layer.weight.data) 864 weight_abs_sum = torch.sum( 865 weight_abs, 866 dim=[ 867 i for i in range(weight_abs.dim()) if i != 0 868 ], # sum over the input dimension 869 ) 870 else: 871 weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 872 weight_abs_sum = torch.sum( 873 weight_abs, 874 dim=[ 875 i for i in range(weight_abs.dim()) if i != 1 876 ], # sum over the output dimension 877 ) 878 879 activation_abs_batch_mean = torch.mean( 880 torch.abs(activation), 881 dim=[ 882 i for i in range(activation.dim()) if i != 1 883 ], # average the features over batch samples 884 ) 885 886 importance_step_layer = weight_abs_sum * activation_abs_batch_mean 887 importance_step_layer = importance_step_layer.detach() 888 889 return importance_step_layer 890 891 def get_importance_step_layer_gradient_x_activation_abs( 892 self: str, 893 layer_name: str, 894 input: Tensor | tuple[Tensor, ...], 895 target: Tensor | None, 896 batch_idx: int, 897 num_batches: int, 898 ) -> Tensor: 899 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum. 900 901 **Args:** 902 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 903 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 904 - **target** (`Tensor` | `None`): the target batch of the training step. 905 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 906 - **num_batches** (`int`): the number of batches in the training step. 
This is an argument of the forward function during training. 907 908 **Returns:** 909 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 910 """ 911 layer = self.backbone.get_layer_by_name(layer_name) 912 913 input = input.requires_grad_() 914 915 # initialize the Layer Gradient X Activation object 916 layer_gradient_x_activation = LayerGradientXActivation( 917 forward_func=self.forward, layer=layer 918 ) 919 920 self.set_forward_func_return_logits_only(True) 921 # calculate layer attribution of the step 922 attribution = layer_gradient_x_activation.attribute( 923 inputs=input, 924 target=target, 925 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 926 ) 927 self.set_forward_func_return_logits_only(False) 928 929 attribution_abs_batch_mean = torch.mean( 930 torch.abs(attribution), 931 dim=[ 932 i for i in range(attribution.dim()) if i != 1 933 ], # average the features over batch samples 934 ) 935 936 importance_step_layer = attribution_abs_batch_mean 937 importance_step_layer = importance_step_layer.detach() 938 939 return importance_step_layer 940 941 def get_importance_step_layer_weight_gradient_square_sum( 942 self: str, 943 layer_name: str, 944 activation: Tensor, 945 if_output_weight: bool, 946 ) -> Tensor: 947 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114). 948 949 **Args:** 950 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 951 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 952 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 
953 954 **Returns:** 955 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 956 """ 957 layer = self.backbone.get_layer_by_name(layer_name) 958 959 if not if_output_weight: 960 gradient_square = layer.weight.grad.data**2 961 gradient_square_sum = torch.sum( 962 gradient_square, 963 dim=[ 964 i for i in range(gradient_square.dim()) if i != 0 965 ], # sum over the input dimension 966 ) 967 else: 968 gradient_square = self.next_layer(layer_name).weight.grad.data**2 969 gradient_square_sum = torch.sum( 970 gradient_square, 971 dim=[ 972 i for i in range(gradient_square.dim()) if i != 1 973 ], # sum over the output dimension 974 ) 975 976 importance_step_layer = gradient_square_sum 977 importance_step_layer = importance_step_layer.detach() 978 979 return importance_step_layer 980 981 def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 982 self: str, 983 layer_name: str, 984 activation: Tensor, 985 if_output_weight: bool, 986 ) -> Tensor: 987 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114). 988 989 **Args:** 990 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 991 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 992 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 993 994 **Returns:** 995 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
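
        **Example:**

        A standalone sketch of the computation (independent of this class; the tensors are made up for illustration), using the input-weights version:

        ```python
        import torch

        # gradient of the input weights, shape (number of units, number of input features)
        weight_grad = torch.tensor([[1.0, -1.0], [2.0, 0.0]])
        # activations with shape (batch size, number of units)
        activation = torch.tensor([[1.0, 2.0], [3.0, 2.0]])

        grad_square_sum = (weight_grad**2).sum(dim=1)  # tensor([2.0, 4.0])
        activation_abs_mean = activation.abs().mean(dim=0)  # tensor([2.0, 2.0])
        importance = grad_square_sum * activation_abs_mean  # tensor([4.0, 8.0])
        ```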
996 """ 997 layer = self.backbone.get_layer_by_name(layer_name) 998 999 if not if_output_weight: 1000 gradient_square = layer.weight.grad.data**2 1001 gradient_square_sum = torch.sum( 1002 gradient_square, 1003 dim=[ 1004 i for i in range(gradient_square.dim()) if i != 0 1005 ], # sum over the input dimension 1006 ) 1007 else: 1008 gradient_square = self.next_layer(layer_name).weight.grad.data**2 1009 gradient_square_sum = torch.sum( 1010 gradient_square, 1011 dim=[ 1012 i for i in range(gradient_square.dim()) if i != 1 1013 ], # sum over the output dimension 1014 ) 1015 1016 activation_abs_batch_mean = torch.mean( 1017 torch.abs(activation), 1018 dim=[ 1019 i for i in range(activation.dim()) if i != 1 1020 ], # average the features over batch samples 1021 ) 1022 1023 importance_step_layer = gradient_square_sum * activation_abs_batch_mean 1024 importance_step_layer = importance_step_layer.detach() 1025 1026 return importance_step_layer 1027 1028 def get_importance_step_layer_conductance_abs( 1029 self: str, 1030 layer_name: str, 1031 input: Tensor | tuple[Tensor, ...], 1032 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1033 target: Tensor | None, 1034 batch_idx: int, 1035 num_batches: int, 1036 ) -> Tensor: 1037 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum. 1038 1039 **Args:** 1040 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1041 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1042 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. 
Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details. 1043 - **target** (`Tensor` | `None`): the target batch of the training step. 1044 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1045 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.- **mask** (`Tensor`): the mask tensor of the layer. It has the same size as the feature tensor with size (number of units, ). 1046 1047 **Returns:** 1048 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1049 """ 1050 layer = self.backbone.get_layer_by_name(layer_name) 1051 1052 # initialize the Layer Conductance object 1053 layer_conductance = LayerConductance(forward_func=self.forward, layer=layer) 1054 1055 self.set_forward_func_return_logits_only(True) 1056 # calculate layer attribution of the step 1057 attribution = layer_conductance.attribute( 1058 inputs=input, 1059 baselines=baselines, 1060 target=target, 1061 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1062 ) 1063 self.set_forward_func_return_logits_only(False) 1064 1065 attribution_abs_batch_mean = torch.mean( 1066 torch.abs(attribution), 1067 dim=[ 1068 i for i in range(attribution.dim()) if i != 1 1069 ], # average the features over batch samples 1070 ) 1071 1072 importance_step_layer = attribution_abs_batch_mean 1073 importance_step_layer = importance_step_layer.detach() 1074 1075 return importance_step_layer 1076 1077 def get_importance_step_layer_internal_influence_abs( 1078 self: str, 1079 layer_name: str, 1080 input: Tensor | tuple[Tensor, ...], 1081 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1082 target: Tensor | None, 1083 batch_idx: int, 1084 num_batches: int, 1085 ) -> Tensor: 1086 r"""Get the raw neuron-wise importance 
(before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum. 1087 1088 **Args:** 1089 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1090 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1091 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details. 1092 - **target** (`Tensor` | `None`): the target batch of the training step. 1093 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1094 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1095 1096 **Returns:** 1097 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1098 """ 1099 layer = self.backbone.get_layer_by_name(layer_name) 1100 1101 # initialize the Internal Influence object 1102 internal_influence = InternalInfluence(forward_func=self.forward, layer=layer) 1103 1104 # convert the target to long type to avoid error 1105 target = target.long() if target is not None else None 1106 1107 self.set_forward_func_return_logits_only(True) 1108 # calculate layer attribution of the step 1109 attribution = internal_influence.attribute( 1110 inputs=input, 1111 baselines=baselines, 1112 target=target, 1113 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1114 n_steps=5, # set 10 instead of default 50 to accelerate the computation 1115 ) 1116 self.set_forward_func_return_logits_only(False) 1117 1118 attribution_abs_batch_mean = torch.mean( 1119 torch.abs(attribution), 1120 dim=[ 1121 i for i in range(attribution.dim()) if i != 1 1122 ], # average the features over batch samples 1123 ) 1124 1125 importance_step_layer = attribution_abs_batch_mean 1126 importance_step_layer = importance_step_layer.detach() 1127 1128 return importance_step_layer 1129 1130 def get_importance_step_layer_gradcam_abs( 1131 self: str, 1132 layer_name: str, 1133 input: Tensor | tuple[Tensor, ...], 1134 target: Tensor | None, 1135 batch_idx: int, 1136 num_batches: int, 1137 ) -> Tensor: 1138 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum. 1139 1140 **Args:** 1141 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1142 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1143 - **target** (`Tensor` | `None`): the target batch of the training step. 
1144 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1145 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1146 1147 **Returns:** 1148 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1149 """ 1150 layer = self.backbone.get_layer_by_name(layer_name) 1151 1152 # initialize the GradCAM object 1153 gradcam = LayerGradCam(forward_func=self.forward, layer=layer) 1154 1155 self.set_forward_func_return_logits_only(True) 1156 # calculate layer attribution of the step 1157 attribution = gradcam.attribute( 1158 inputs=input, 1159 target=target, 1160 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1161 ) 1162 self.set_forward_func_return_logits_only(False) 1163 1164 attribution_abs_batch_mean = torch.mean( 1165 torch.abs(attribution), 1166 dim=[ 1167 i for i in range(attribution.dim()) if i != 1 1168 ], # average the features over batch samples 1169 ) 1170 1171 importance_step_layer = attribution_abs_batch_mean 1172 importance_step_layer = importance_step_layer.detach() 1173 1174 return importance_step_layer 1175 1176 def get_importance_step_layer_deeplift_abs( 1177 self: str, 1178 layer_name: str, 1179 input: Tensor | tuple[Tensor, ...], 1180 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1181 target: Tensor | None, 1182 batch_idx: int, 1183 num_batches: int, 1184 ) -> Tensor: 1185 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum. 
1186 1187 **Args:** 1188 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1189 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1190 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details. 1191 - **target** (`Tensor` | `None`): the target batch of the training step. 1192 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1193 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1194 1195 **Returns:** 1196 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1197 """ 1198 layer = self.backbone.get_layer_by_name(layer_name) 1199 1200 # initialize the Layer DeepLift object 1201 layer_deeplift = LayerDeepLift(model=self, layer=layer) 1202 1203 # convert the target to long type to avoid error 1204 target = target.long() if target is not None else None 1205 1206 self.set_forward_func_return_logits_only(True) 1207 # calculate layer attribution of the step 1208 attribution = layer_deeplift.attribute( 1209 inputs=input, 1210 baselines=baselines, 1211 target=target, 1212 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1213 ) 1214 self.set_forward_func_return_logits_only(False) 1215 1216 attribution_abs_batch_mean = torch.mean( 1217 torch.abs(attribution), 1218 dim=[ 1219 i for i in range(attribution.dim()) if i != 1 1220 ], # average the features over batch samples 1221 ) 1222 1223 importance_step_layer = attribution_abs_batch_mean 1224 importance_step_layer = importance_step_layer.detach() 1225 1226 return importance_step_layer 1227 1228 def 
get_importance_step_layer_deepliftshap_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer DeepLiftShap object
        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)

        # convert the target to long type to avoid errors
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate layer attribution of the step
        attribution = layer_deepliftshap.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer

    def get_importance_step_layer_gradientshap_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if baselines is None:
            # baselines are mandatory for the GradientShap API, so we explicitly set them to zero
            baselines = torch.zeros_like(input)

        # initialize the Layer GradientShap object
        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)

        # convert the target to long type to avoid errors
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate layer attribution of the step
        attribution = layer_gradientshap.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer

    def get_importance_step_layer_integrated_gradients_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
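
        **Example:**

        A standalone illustration of integrated gradients itself (computed manually here rather than via Captum; the model and values are made up): for a linear scorer $f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}$, the attribution from a zero baseline equals $\mathbf{w} \odot \mathbf{x}$ exactly.

        ```python
        import torch

        w = torch.tensor([1.0, -2.0, 3.0])
        x = torch.tensor([0.5, 1.0, -1.0])
        baseline = torch.zeros(3)

        # Riemann approximation of the path integral of gradients from baseline to x
        n_steps = 50
        grads = []
        for alpha in torch.linspace(1.0 / n_steps, 1.0, n_steps):
            point = (baseline + alpha * (x - baseline)).requires_grad_()
            torch.dot(w, point).backward()
            grads.append(point.grad)
        ig = (x - baseline) * torch.stack(grads).mean(dim=0)
        # for a linear scorer this equals w * x, i.e. tensor([0.5, -2.0, -3.0])
        ```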
1358 """ 1359 layer = self.backbone.get_layer_by_name(layer_name) 1360 1361 # initialize the Layer Integrated Gradients object 1362 layer_integrated_gradients = LayerIntegratedGradients( 1363 forward_func=self.forward, layer=layer 1364 ) 1365 1366 self.set_forward_func_return_logits_only(True) 1367 # calculate layer attribution of the step 1368 attribution = layer_integrated_gradients.attribute( 1369 inputs=input, 1370 baselines=baselines, 1371 target=target, 1372 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1373 ) 1374 self.set_forward_func_return_logits_only(False) 1375 1376 attribution_abs_batch_mean = torch.mean( 1377 torch.abs(attribution), 1378 dim=[ 1379 i for i in range(attribution.dim()) if i != 1 1380 ], # average the features over batch samples 1381 ) 1382 1383 importance_step_layer = attribution_abs_batch_mean 1384 importance_step_layer = importance_step_layer.detach() 1385 1386 return importance_step_layer 1387 1388 def get_importance_step_layer_feature_ablation_abs( 1389 self: str, 1390 layer_name: str, 1391 input: Tensor | tuple[Tensor, ...], 1392 layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1393 target: Tensor | None, 1394 batch_idx: int, 1395 num_batches: int, 1396 if_captum: bool = False, 1397 ) -> Tensor: 1398 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum. 1399 1400 **Args:** 1401 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1402 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 
1403 - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details. 1404 - **target** (`Tensor` | `None`): the target batch of the training step. 1405 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1406 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1407 - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster. 1408 1409 **Returns:** 1410 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1411 """ 1412 layer = self.backbone.get_layer_by_name(layer_name) 1413 1414 if not if_captum: 1415 # 1. Baseline logits (take first element of forward output) 1416 baseline_out, _, _ = self.forward( 1417 input, "train", batch_idx, num_batches, self.task_id 1418 ) 1419 if target is not None: 1420 baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1) 1421 else: 1422 baseline_scores = baseline_out.sum(dim=1) 1423 1424 # 2. Capture layer’s output shape 1425 activs = {} 1426 handle = layer.register_forward_hook( 1427 lambda module, inp, out: activs.setdefault("output", out.detach()) 1428 ) 1429 _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id) 1430 handle.remove() 1431 layer_output = activs["output"] # shape (B, F, ...) 1432 1433 # 3. 
Build baseline tensor matching that shape 1434 if layer_baselines is None: 1435 baseline_tensor = torch.zeros_like(layer_output) 1436 elif isinstance(layer_baselines, (int, float)): 1437 baseline_tensor = torch.full_like(layer_output, layer_baselines) 1438 elif isinstance(layer_baselines, Tensor): 1439 if layer_baselines.shape == layer_output.shape: 1440 baseline_tensor = layer_baselines 1441 elif layer_baselines.shape == layer_output.shape[1:]: 1442 baseline_tensor = layer_baselines.unsqueeze(0).repeat( 1443 layer_output.size(0), *([1] * layer_baselines.ndim) 1444 ) 1445 else: 1446 raise ValueError("layer_baselines tensor must match the layer output shape, with or without the batch dimension") 1447 else: 1448 raise ValueError(f"Unsupported type for layer_baselines: {type(layer_baselines)}") 1449 1450 B, F = layer_output.size(0), layer_output.size(1) 1451 1452 # 4. Create a “mega-batch” replicating the input F times 1453 if isinstance(input, tuple): 1454 mega_inputs = tuple( 1455 t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:]) 1456 for t in input 1457 ) 1458 else: 1459 mega_inputs = ( 1460 input.unsqueeze(0) 1461 .repeat(F, *([1] * input.ndim)) 1462 .view(-1, *input.shape[1:]) 1463 ) 1464 1465 # 5. Equally replicate the baseline tensor 1466 mega_baseline = ( 1467 baseline_tensor.unsqueeze(0) 1468 .repeat(F, *([1] * baseline_tensor.ndim)) 1469 .view(-1, *baseline_tensor.shape[1:]) 1470 ) 1471 1472 # 6. Precompute vectorized indices 1473 device = layer_output.device 1474 positions = torch.arange(F * B, device=device) # [0,1,...,F*B-1] 1475 feat_idx = torch.arange(F, device=device).repeat_interleave( 1476 B 1477 ) # [0,0,...,1,1,...,F-1] 1478 1479 # 7.
One hook to zero out each channel slice across the mega-batch 1480 def mega_ablate_hook(module, inp, out): 1481 out_mod = out.clone() 1482 # for each sample in mega-batch, zero its corresponding channel 1483 out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx] 1484 return out_mod 1485 1486 h = layer.register_forward_hook(mega_ablate_hook) 1487 out_all, _, _ = self.forward( 1488 mega_inputs, "train", batch_idx, num_batches, self.task_id 1489 ) 1490 h.remove() 1491 1492 # 8. Recover scores, reshape [F*B] → [F, B], diff & mean 1493 if target is not None: 1494 tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1) 1495 scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1) 1496 else: 1497 scores_all = out_all.sum(dim=1) 1498 1499 scores_all = scores_all.view(F, B) 1500 diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all) 1501 importance_step_layer = diffs.mean(dim=1).detach() # [F] 1502 1503 return importance_step_layer 1504 1505 else: 1506 # initialize the Layer Feature Ablation object 1507 layer_feature_ablation = LayerFeatureAblation( 1508 forward_func=self.forward, layer=layer 1509 ) 1510 1511 # calculate layer attribution of the step 1512 self.set_forward_func_return_logits_only(True) 1513 attribution = layer_feature_ablation.attribute( 1514 inputs=input, 1515 layer_baselines=layer_baselines, 1516 # target=target, # disable target to enable perturbations_per_eval 1517 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1518 perturbations_per_eval=128, # to accelerate the computation 1519 ) 1520 self.set_forward_func_return_logits_only(False) 1521 1522 attribution_abs_batch_mean = torch.mean( 1523 torch.abs(attribution), 1524 dim=[ 1525 i for i in range(attribution.dim()) if i != 1 1526 ], # average the features over batch samples 1527 ) 1528 1529 importance_step_layer = attribution_abs_batch_mean 1530 importance_step_layer = importance_step_layer.detach() 1531 1532 return importance_step_layer 1533 1534 def 
get_importance_step_layer_lrp_abs( 1535 self, 1536 layer_name: str, 1537 input: Tensor | tuple[Tensor, ...], 1538 target: Tensor | None, 1539 batch_idx: int, 1540 num_batches: int, 1541 ) -> Tensor: 1542 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum. 1543 1544 **Args:** 1545 - **layer_name** (`str`): the name of the layer to get neuron-wise importance. 1546 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1547 - **target** (`Tensor` | `None`): the target batch of the training step. 1548 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1549 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1550 1551 **Returns:** 1552 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1553 """ 1554 layer = self.backbone.get_layer_by_name(layer_name) 1555 1556 # initialize the Layer LRP object 1557 layer_lrp = LayerLRP(model=self, layer=layer) 1558 1559 # set model to evaluation mode to prevent updating the model parameters 1560 self.eval() 1561 1562 self.set_forward_func_return_logits_only(True) 1563 # calculate layer attribution of the step 1564 attribution = layer_lrp.attribute( 1565 inputs=input, 1566 target=target, 1567 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1568 ) 1569 self.set_forward_func_return_logits_only(False) 1570 1571 attribution_abs_batch_mean = torch.mean( 1572 torch.abs(attribution), 1573 dim=[ 1574 i for i in range(attribution.dim()) if i != 1 1575 ], # average the features over batch samples 1576 ) 1577 1578 importance_step_layer = attribution_abs_batch_mean 1579 importance_step_layer = importance_step_layer.detach() 1580 1581 return importance_step_layer 1582 1583 def get_importance_step_layer_cbp_adaptive_contribution( 1584 self, 1585 layer_name: str, 1586 activation: Tensor, 1587 ) -> Tensor: 1588 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7). 1589 1590 **Args:** 1591 - **layer_name** (`str`): the name of the layer to get neuron-wise importance. 1592 - **activation** (`Tensor`): the activation tensor of the layer, of size (batch size, number of units). 1593 1594 **Returns:** 1595 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1596 """ 1597 layer = self.backbone.get_layer_by_name(layer_name) 1598 1599 input_weight_abs = torch.abs(layer.weight.data) 1600 input_weight_abs_sum = torch.sum( 1601 input_weight_abs, 1602 dim=[ 1603 i for i in range(input_weight_abs.dim()) if i != 0 1604 ], # sum over the input dimension 1605 ) 1606 input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum) 1607 1608 output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 1609 output_weight_abs_sum = torch.sum( 1610 output_weight_abs, 1611 dim=[ 1612 i for i in range(output_weight_abs.dim()) if i != 1 1613 ], # sum over the output dimension 1614 ) 1615 1616 activation_abs_batch_mean = torch.mean( 1617 torch.abs(activation), 1618 dim=[ 1619 i for i in range(activation.dim()) if i != 1 1620 ], # average the features over batch samples 1621 ) 1622 1623 importance_step_layer = ( 1624 output_weight_abs_sum 1625 * activation_abs_batch_mean 1626 * input_weight_abs_sum_reciprocal 1627 ) 1628 importance_step_layer = importance_step_layer.detach() 1629 1630 return importance_step_layer
37class FGAdaHAT(AdaHAT): 38 r"""FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm. 39 40 An architecture-based continual learning approach that improves [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT. 41 42 We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components. 43 """ 44 45 def __init__( 46 self, 47 backbone: HATMaskBackbone, 48 heads: HeadsTIL | HeadDIL, 49 adjustment_intensity: float, 50 importance_type: str, 51 importance_summing_strategy: str, 52 importance_scheduler_type: str, 53 neuron_to_weight_importance_aggregation_mode: str, 54 s_max: float, 55 clamp_threshold: float, 56 mask_sparsity_reg_factor: float, 57 mask_sparsity_reg_mode: str = "original", 58 base_importance: float = 0.01, 59 base_mask_sparsity_reg: float = 0.1, 60 base_linear: float = 10, 61 filter_by_cumulative_mask: bool = False, 62 filter_unmasked_importance: bool = True, 63 step_multiply_training_mask: bool = True, 64 task_embedding_init_mode: str = "N01", 65 importance_summing_strategy_linear_step: float | None = None, 66 importance_summing_strategy_exponential_rate: float | None = None, 67 importance_summing_strategy_log_base: float | None = None, 68 non_algorithmic_hparams: dict[str, Any] = {}, 69 **kwargs, 70 ) -> None: 71 r"""Initialize the FG-AdaHAT algorithm with the network. 72 73 **Args:** 74 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 75 - **heads** (`HeadsTIL` | `HeadDIL`): output heads. FG-AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning). 76 - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper). 
77 - **importance_type** (`str`): the type of neuron-wise importance, must be one of: 78 1. 'input_weight_abs_sum': sum of absolute input weights; 79 2. 'output_weight_abs_sum': sum of absolute output weights; 80 3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper); 81 4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper); 82 5. 'activation_abs': absolute activation; 83 6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper); 84 7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper); 85 8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation; 86 9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights; 87 10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights; 88 11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper); 89 12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation; 90 13. 'conductance_abs': absolute layer conductance; 91 14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper); 92 15. 'gradcam_abs': absolute Grad-CAM; 93 16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper); 94 17. 'deepliftshap_abs': absolute DeepLIFT-SHAP; 95 18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper); 96 19. 'integrated_gradients_abs': absolute Integrated Gradients; 97 20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper); 98 21. 
'lrp_abs': absolute Layer-wise Relevance Propagation (LRP); 99 22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 100 23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 101 - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of: 102 1. 'add_latest': add the latest neuron-wise importance to the summative importance; 103 2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance; 104 3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance; 105 4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID; 106 5. 'quadratic_decrease': weigh the previous neuron-wise importance by a factor that decreases quadratically with the task ID; 107 6. 'cubic_decrease': weigh the previous neuron-wise importance by a factor that decreases cubically with the task ID; 108 7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID; 109 8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID; 110 9. 'factorial_decrease': weigh the previous neuron-wise importance by a factor that decreases factorially with the task ID; 111 - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ applied to the parameter importance. Must be one of: 112 1.
'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`); 113 2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$; 114 3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$. 115 - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of: 116 1. 'min': take the minimum of neuron-wise importance for each weight; 117 2. 'max': take the maximum of neuron-wise importance for each weight; 118 3. 'mean': take the mean of neuron-wise importance for each weight. 119 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 120 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 121 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity. 122 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 123 1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 124 2. 'cross': the cross version of mask sparsity regularization. 125 - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01.
126 - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1. 127 - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10. 128 - **filter_by_cumulative_mask** (`bool`): whether to multiply the importance by the cumulative mask when calculating the adjustment rate. Default: False. 129 - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set them to 0) at the end of task training. Default: True. 130 - **step_multiply_training_mask** (`bool`): whether to multiply the importance by the training mask at each training step. Default: True. 131 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 132 1. 'N01' (default): standard normal distribution $N(0, 1)$. 133 2. 'U-11': uniform distribution $U(-1, 1)$. 134 3. 'U01': uniform distribution $U(0, 1)$. 135 4. 'U-10': uniform distribution $U(-1, 0)$. 136 5. 'last': inherit the task embedding from the last task. 137 - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0. 138 - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1. 139 - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1. 140 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (not related to the algorithm itself, such as optimizer and learning rate scheduler configurations) passed to this `LightningModule` object from the config.
They are saved via Lightning's `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility. 141 - **kwargs**: Reserved for multiple inheritance. 142 143 """ 144 super().__init__( 145 backbone=backbone, 146 heads=heads, 147 adjustment_mode=None, # use FG-AdaHAT's own adjustment mechanism 148 adjustment_intensity=adjustment_intensity, 149 s_max=s_max, 150 clamp_threshold=clamp_threshold, 151 mask_sparsity_reg_factor=mask_sparsity_reg_factor, 152 mask_sparsity_reg_mode=mask_sparsity_reg_mode, 153 task_embedding_init_mode=task_embedding_init_mode, 154 epsilon=base_mask_sparsity_reg, # AdaHAT's epsilon now serves as the base mask sparsity regularization factor 155 non_algorithmic_hparams=non_algorithmic_hparams, 156 **kwargs, 157 ) 158 159 # save additional algorithmic hyperparameters 160 self.save_hyperparameters( 161 "adjustment_intensity", 162 "importance_type", 163 "importance_summing_strategy", 164 "importance_scheduler_type", 165 "neuron_to_weight_importance_aggregation_mode", 166 "s_max", 167 "clamp_threshold", 168 "mask_sparsity_reg_factor", 169 "mask_sparsity_reg_mode", 170 "base_importance", 171 "base_mask_sparsity_reg", 172 "base_linear", 173 "filter_by_cumulative_mask", 174 "filter_unmasked_importance", 175 "step_multiply_training_mask", 176 ) 177 178 self.importance_type: str = importance_type 179 r"""The type of the neuron-wise importance added to AdaHAT importance.""" 180 181 self.importance_scheduler_type: str = importance_scheduler_type 182 r"""The type of the importance scheduler.""" 183 self.neuron_to_weight_importance_aggregation_mode: str = ( 184 neuron_to_weight_importance_aggregation_mode 185 ) 186 r"""The mode of aggregation from neuron-wise to weight-wise importance.
""" 187 self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask 188 r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate.""" 189 self.filter_unmasked_importance: bool = filter_unmasked_importance 190 r"""The flag to filter unmasked importance values (set them to 0) at the end of task training.""" 191 self.step_multiply_training_mask: bool = step_multiply_training_mask 192 r"""The flag to multiply the training mask to the importance at each training step.""" 193 194 # importance summing strategy 195 self.importance_summing_strategy: str = importance_summing_strategy 196 r"""The strategy to sum the neuron-wise importance for previous tasks.""" 197 if importance_summing_strategy_linear_step is not None: 198 self.importance_summing_strategy_linear_step: float = ( 199 importance_summing_strategy_linear_step 200 ) 201 r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease').""" 202 if importance_summing_strategy_exponential_rate is not None: 203 self.importance_summing_strategy_exponential_rate: float = ( 204 importance_summing_strategy_exponential_rate 205 ) 206 r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """ 207 if importance_summing_strategy_log_base is not None: 208 self.importance_summing_strategy_log_base: float = ( 209 importance_summing_strategy_log_base 210 ) 211 r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """ 212 213 # base values 214 self.base_importance: float = base_importance 215 r"""The base value added to the importance to avoid zero. """ 216 self.base_mask_sparsity_reg: float = base_mask_sparsity_reg 217 r"""The base value added to the mask sparsity regularization to avoid zero. 
""" 218 self.base_linear: float = base_linear 219 r"""The base value added to the linear factor in the importance scheduler to avoid zero. """ 220 221 self.importances: dict[int, dict[str, Tensor]] = {} 222 r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are dicts mapping layer names to the importance tensors for those layers. Each importance tensor has the same size as the layer's feature tensor, (number of units, ). """ 223 self.summative_importance_for_previous_tasks: dict[str, Tensor] = {} 224 r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """ 225 226 self.num_steps_t: int 227 r"""The number of training steps for the current task `self.task_id`.""" 228 # set manual optimization 229 self.automatic_optimization = False 230 231 FGAdaHAT.sanity_check(self) 232 233 def sanity_check(self) -> None: 234 r"""Sanity check.""" 235 236 # check importance type 237 if self.importance_type not in [ 238 "input_weight_abs_sum", 239 "output_weight_abs_sum", 240 "input_weight_gradient_abs_sum", 241 "output_weight_gradient_abs_sum", 242 "activation_abs", 243 "input_weight_abs_sum_x_activation_abs", 244 "output_weight_abs_sum_x_activation_abs", 245 "gradient_x_activation_abs", 246 "input_weight_gradient_square_sum", 247 "output_weight_gradient_square_sum", 248 "input_weight_gradient_square_sum_x_activation_abs", 249 "output_weight_gradient_square_sum_x_activation_abs", 250 "conductance_abs", 251 "internal_influence_abs", 252 "gradcam_abs", 253 "deeplift_abs", 254 "deepliftshap_abs", 255 "gradientshap_abs", 256 "integrated_gradients_abs", 257 "feature_ablation_abs", 258
"lrp_abs", 259 "cbp_adaptation", 260 "cbp_adaptive_contribution", 261 ]: 262 raise ValueError( 263 f"importance_type must be one of the predefined types, but got {self.importance_type}" 264 ) 265 266 # check importance summing strategy 267 if self.importance_summing_strategy not in [ 268 "add_latest", 269 "add_all", 270 "add_average", 271 "linear_decrease", 272 "quadratic_decrease", 273 "cubic_decrease", 274 "exponential_decrease", 275 "log_decrease", 276 "factorial_decrease", 277 ]: 278 raise ValueError( 279 f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}" 280 ) 281 282 # check importance scheduler type 283 if self.importance_scheduler_type not in [ 284 "linear_sparsity_reg", 285 "sparsity_reg", 286 "summative_mask_sparsity_reg", 287 ]: 288 raise ValueError( 289 f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}" 290 ) 291 292 # check neuron to weight importance aggregation mode 293 if self.neuron_to_weight_importance_aggregation_mode not in [ 294 "min", 295 "max", 296 "mean", 297 ]: 298 raise ValueError( 299 f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}" 300 ) 301 302 # check base values 303 if self.base_importance < 0: 304 raise ValueError( 305 f"base_importance must be >= 0, but got {self.base_importance}" 306 ) 307 if self.base_mask_sparsity_reg <= 0: 308 raise ValueError( 309 f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}" 310 ) 311 if self.base_linear <= 0: 312 raise ValueError(f"base_linear must be > 0, but got {self.base_linear}") 313 314 def on_train_start(self) -> None: 315 r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.""" 316 super().on_train_start() 317 318 self.importances[self.task_id] = ( 319 {} 320 ) # initialize the 
importance for the current task 321 322 # initialize the neuron importance at the beginning of each task. This cannot be done in the `__init__()` method because `self.device` is not available at that time. 323 for layer_name in self.backbone.weighted_layer_names: 324 layer = self.backbone.get_layer_by_name( 325 layer_name 326 ) # get the layer by its name 327 num_units = layer.weight.shape[0] 328 329 # initialize the accumulated importance at the beginning of each task 330 self.importances[self.task_id][layer_name] = torch.zeros(num_units).to( 331 self.device 332 ) 333 334 # reset the number of steps counter for the current task 335 self.num_steps_t = 0 336 337 # initialize the summative neuron-wise importance at the beginning of the first task 338 if self.task_id == 1: 339 self.summative_importance_for_previous_tasks[layer_name] = torch.zeros( 340 num_units 341 ).to( 342 self.device 343 ) # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as a zero mask when $t=1$ 344 345 def clip_grad_by_adjustment( 346 self, 347 network_sparsity: dict[str, Tensor], 348 ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]: 349 r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper. 350 351 Because the task embeddings cover every layer in the backbone network, no parameters are left out of this mechanism. It applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code. 352 353 Network capacity is measured alongside this computation. It is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 354 355 **Args:** 356 - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values.
It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler. 357 358 **Returns:** 359 - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 360 - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 361 - **capacity** (`Tensor`): the calculated network capacity. 362 """ 363 364 # initialize network capacity metric 365 capacity = HATNetworkCapacityMetric().to(self.device) 366 adjustment_rate_weight = {} 367 adjustment_rate_bias = {} 368 369 # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper 370 for layer_name in self.backbone.weighted_layer_names: 371 372 layer = self.backbone.get_layer_by_name( 373 layer_name 374 ) # get the layer by its name 375 376 # placeholder for the adjustment rate to avoid the error of using it before assignment 377 adjustment_rate_weight_layer = 1 378 adjustment_rate_bias_layer = 1 379 380 # aggregate the neuron-wise importance to weight-wise importance. 
Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, with the base value added and the mask applied 381 weight_importance, bias_importance = ( 382 self.backbone.get_layer_measure_parameter_wise( 383 neuron_wise_measure=self.summative_importance_for_previous_tasks, 384 layer_name=layer_name, 385 aggregation_mode=self.neuron_to_weight_importance_aggregation_mode, 386 ) 387 ) 388 389 weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise( 390 neuron_wise_measure=self.cumulative_mask_for_previous_tasks, 391 layer_name=layer_name, 392 aggregation_mode="min", 393 ) 394 395 # filter the weight importance by the cumulative mask 396 if self.filter_by_cumulative_mask: 397 weight_importance = weight_importance * weight_mask 398 bias_importance = bias_importance * bias_mask 399 400 network_sparsity_layer = network_sparsity[layer_name] 401 402 # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper 403 factor = network_sparsity_layer + self.base_mask_sparsity_reg 404 if self.importance_scheduler_type == "linear_sparsity_reg": 405 factor = factor * (self.task_id + self.base_linear) 406 factor_weight = factor_bias = factor 407 if self.importance_scheduler_type == "summative_mask_sparsity_reg": 408 # min-aggregate the neuron-wise summative mask to parameter-wise factors. See Eq. (3) 409 summative_mask_weight, summative_mask_bias = self.backbone.get_layer_measure_parameter_wise( 410 neuron_wise_measure=self.summative_mask_for_previous_tasks, 411 layer_name=layer_name, 412 aggregation_mode="min", 413 ) 414 factor_weight = factor * (summative_mask_weight + self.base_linear) 415 factor_bias = factor * (summative_mask_bias + self.base_linear) 416 417 # calculate the adjustment rate 418 adjustment_rate_weight_layer = torch.div( 419 self.adjustment_intensity, factor_weight * weight_importance + self.adjustment_intensity 420 ) 421 adjustment_rate_bias_layer = torch.div( 422 self.adjustment_intensity, factor_bias * bias_importance + self.adjustment_intensity 423 ) 424 425 # apply the adjustment rate to the gradients 426 layer.weight.grad.data *= adjustment_rate_weight_layer 427 if layer.bias is not None: 428 layer.bias.grad.data *= adjustment_rate_bias_layer 429 # store the adjustment rate for
logging 430 adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer 431 if layer.bias is not None: 432 adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer 433 434 # update network capacity metric 435 capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer) 436 437 return adjustment_rate_weight, adjustment_rate_bias, capacity.compute() 438 439 def on_train_batch_end( 440 self, outputs: dict[str, Any], batch: Any, batch_idx: int 441 ) -> None: 442 r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step. 443 444 **Args:** 445 - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`). 446 - **batch** (`Any`): training data batch. 447 - **batch_idx** (`int`): index of the current batch (for mask figure file name). 448 """ 449 450 # get potential useful information from training batch 451 activations = outputs["activations"] 452 input = outputs["input"] 453 target = outputs["target"] 454 mask = outputs["mask"] 455 num_batches = self.trainer.num_training_batches 456 457 for layer_name in self.backbone.weighted_layer_names: 458 # layer-wise operation 459 460 activation = activations[layer_name] 461 462 # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. 
        if self.importance_type == "input_weight_abs_sum":
            importance_step = self.get_importance_step_layer_weight_abs_sum(
                layer_name=layer_name,
                if_output_weight=False,
                reciprocal=False,
            )
        elif self.importance_type == "output_weight_abs_sum":
            importance_step = self.get_importance_step_layer_weight_abs_sum(
                layer_name=layer_name,
                if_output_weight=True,
                reciprocal=False,
            )
        elif self.importance_type == "input_weight_gradient_abs_sum":
            importance_step = (
                self.get_importance_step_layer_weight_gradient_abs_sum(
                    layer_name=layer_name, if_output_weight=False
                )
            )
        elif self.importance_type == "output_weight_gradient_abs_sum":
            importance_step = (
                self.get_importance_step_layer_weight_gradient_abs_sum(
                    layer_name=layer_name, if_output_weight=True
                )
            )
        elif self.importance_type == "activation_abs":
            importance_step = self.get_importance_step_layer_activation_abs(
                activation=activation
            )
        elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
            importance_step = (
                self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=False,
                )
            )
        elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
            importance_step = (
                self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=True,
                )
            )
        elif self.importance_type == "gradient_x_activation_abs":
            importance_step = (
                self.get_importance_step_layer_gradient_x_activation_abs(
                    layer_name=layer_name,
                    input=input,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            )
        elif self.importance_type == "input_weight_gradient_square_sum":
            importance_step = (
                self.get_importance_step_layer_weight_gradient_square_sum(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=False,
                )
            )
        elif self.importance_type == "output_weight_gradient_square_sum":
            importance_step = (
                self.get_importance_step_layer_weight_gradient_square_sum(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=True,
                )
            )
        elif (
            self.importance_type
            == "input_weight_gradient_square_sum_x_activation_abs"
        ):
            importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
                layer_name=layer_name,
                activation=activation,
                if_output_weight=False,
            )
        elif (
            self.importance_type
            == "output_weight_gradient_square_sum_x_activation_abs"
        ):
            importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
                layer_name=layer_name,
                activation=activation,
                if_output_weight=True,
            )
        elif self.importance_type == "conductance_abs":
            importance_step = self.get_importance_step_layer_conductance_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "internal_influence_abs":
            importance_step = self.get_importance_step_layer_internal_influence_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "gradcam_abs":
            importance_step = self.get_importance_step_layer_gradcam_abs(
                layer_name=layer_name,
                input=input,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "deeplift_abs":
            importance_step = self.get_importance_step_layer_deeplift_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "deepliftshap_abs":
            importance_step = self.get_importance_step_layer_deepliftshap_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "gradientshap_abs":
            importance_step = self.get_importance_step_layer_gradientshap_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "integrated_gradients_abs":
            importance_step = (
                self.get_importance_step_layer_integrated_gradients_abs(
                    layer_name=layer_name,
                    input=input,
                    baselines=None,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            )
        elif self.importance_type == "feature_ablation_abs":
            importance_step = self.get_importance_step_layer_feature_ablation_abs(
                layer_name=layer_name,
                input=input,
                layer_baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "lrp_abs":
            importance_step = self.get_importance_step_layer_lrp_abs(
                layer_name=layer_name,
                input=input,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "cbp_adaptation":
            importance_step = self.get_importance_step_layer_weight_abs_sum(
                layer_name=layer_name,
                if_output_weight=False,
                reciprocal=True,
            )
        elif self.importance_type == "cbp_adaptive_contribution":
            importance_step = (
                self.get_importance_step_layer_cbp_adaptive_contribution(
                    layer_name=layer_name,
                    activation=activation,
                )
            )

        importance_step = min_max_normalize(
            importance_step
        )  # min-max scale the importance to $[0, 1]$. See Eq. (5) in the paper

        # multiply the importance by the training mask. See Eq. (6) in the paper
        if self.step_multiply_training_mask:
            importance_step = importance_step * mask[layer_name]

        # update the accumulated importance
        self.importances[self.task_id][layer_name] = (
            self.importances[self.task_id][layer_name] + importance_step
        )

        # update the number-of-steps counter
        self.num_steps_t += 1

    def on_train_end(self) -> None:
        r"""Additionally calculate the neuron-wise importance for previous tasks at the end of training each task."""
        super().on_train_end()  # store the mask and update cumulative and summative masks

        for layer_name in self.backbone.weighted_layer_names:

            # average the neuron-wise step importance. See Eq. (4) in the paper
            self.importances[self.task_id][layer_name] = (
                self.importances[self.task_id][layer_name] / self.num_steps_t
            )

            # add the base importance. See Eq. (6) in the paper
            self.importances[self.task_id][layer_name] = (
                self.importances[self.task_id][layer_name] + self.base_importance
            )

            # filter out the unmasked importance
            if self.filter_unmasked_importance:
                self.importances[self.task_id][layer_name] = (
                    self.importances[self.task_id][layer_name]
                    * self.backbone.masks[f"{self.task_id}"][layer_name]
                )

            # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
            if self.importance_summing_strategy == "add_latest":
                self.summative_importance_for_previous_tasks[
                    layer_name
                ] += self.importances[self.task_id][layer_name]
            elif self.importance_summing_strategy == "add_all":
                for t in range(1, self.task_id + 1):
                    self.summative_importance_for_previous_tasks[
                        layer_name
                    ] += self.importances[t][layer_name]
            elif self.importance_summing_strategy == "add_average":
                for t in range(1, self.task_id + 1):
                    self.summative_importance_for_previous_tasks[layer_name] += (
                        self.importances[t][layer_name] / self.task_id
                    )
            else:
                # decreasing-weight strategies: restart the accumulation from zero
                self.summative_importance_for_previous_tasks[
                    layer_name
                ] = torch.zeros_like(
                    self.summative_importance_for_previous_tasks[layer_name]
                ).to(self.device)

                for t in range(1, self.task_id + 1):
                    if self.importance_summing_strategy == "linear_decrease":
                        s = self.importance_summing_strategy_linear_step
                        w_t = s * (self.task_id - t) + 1
                    elif self.importance_summing_strategy == "quadratic_decrease":
                        w_t = (self.task_id - t + 1) ** 2
                    elif self.importance_summing_strategy == "cubic_decrease":
                        w_t = (self.task_id - t + 1) ** 3
                    elif self.importance_summing_strategy == "exponential_decrease":
                        r = self.importance_summing_strategy_exponential_rate
                        w_t = r ** (self.task_id - t + 1)
                    elif self.importance_summing_strategy == "log_decrease":
                        a = self.importance_summing_strategy_log_base
                        # shift the argument by 1 so the latest task (t == task_id)
                        # gets weight log_a(1) + 1 = 1 instead of the undefined log_a(0)
                        w_t = math.log(self.task_id - t + 1, a) + 1
                    elif self.importance_summing_strategy == "factorial_decrease":
                        w_t = math.factorial(self.task_id - t + 1)
                    else:
                        raise ValueError(
                            f"Unknown importance_summing_strategy: {self.importance_summing_strategy}"
                        )
                    # accumulate the weighted importance of task t
                    self.summative_importance_for_previous_tasks[layer_name] += (
                        self.importances[t][layer_name] * w_t
                    )

    def get_importance_step_layer_weight_abs_sum(
        self,
        layer_name: str,
        if_output_weight: bool,
        reciprocal: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the layer input or output weights.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
        - **reciprocal** (`bool`): whether to take the reciprocal.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            weight_abs = torch.abs(layer.weight.data)
            weight_abs_sum = torch.sum(
                weight_abs,
                dim=[
                    i for i in range(weight_abs.dim()) if i != 0
                ],  # sum over the input dimension
            )
        else:
            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
            weight_abs_sum = torch.sum(
                weight_abs,
                dim=[
                    i for i in range(weight_abs.dim()) if i != 1
                ],  # sum over the output dimension
            )

        if reciprocal:
            importance_step_layer = torch.reciprocal(weight_abs_sum)
        else:
            importance_step_layer = weight_abs_sum
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer

    def get_importance_step_layer_weight_gradient_abs_sum(
        self,
        layer_name: str,
        if_output_weight: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the gradients of the layer input or output weights.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **if_output_weight** (`bool`): whether to use the output weights or input weights.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            gradient_abs = torch.abs(layer.weight.grad.data)
            gradient_abs_sum = torch.sum(
                gradient_abs,
                dim=[
                    i for i in range(gradient_abs.dim()) if i != 0
                ],  # sum over the input dimension
            )
        else:
            gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data)
            gradient_abs_sum = torch.sum(
                gradient_abs,
                dim=[
                    i for i in range(gradient_abs.dim()) if i != 1
                ],  # sum over the output dimension
            )

        importance_step_layer = gradient_abs_sum.detach()

        return importance_step_layer

    def get_importance_step_layer_activation_abs(
        self,
        activation: Tensor,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of the activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum.

        **Args:**
        - **activation** (`Tensor`): the activation tensor of the layer, of size (batch size, number of units).

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
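
For intuition, the batch reduction can be sketched with a toy activation batch (a minimal standalone sketch with hypothetical numbers, not this class's code path):

```python
import torch

# hypothetical activation batch of shape (batch size = 2, units = 3)
activation = torch.tensor([[1.0, -1.0, 2.0],
                           [3.0, 1.0, -2.0]])

# average |activation| over every dimension except dim 1 (the units)
importance = activation.abs().mean(
    dim=[d for d in range(activation.dim()) if d != 1]
)
# importance is one score per unit: tensor([2.0, 1.0, 2.0])
```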
        """
        activation_abs_batch_mean = torch.mean(
            torch.abs(activation),
            dim=[
                i for i in range(activation.dim()) if i != 1
            ],  # average the features over batch samples
        )
        importance_step_layer = activation_abs_batch_mean.detach()

        return importance_step_layer

    def get_importance_step_layer_weight_abs_sum_x_activation_abs(
        self,
        layer_name: str,
        activation: Tensor,
        if_output_weight: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the layer input / output weights multiplied by the absolute values of the activation. The input-weight version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **activation** (`Tensor`): the activation tensor of the layer, of size (batch size, number of units).
        - **if_output_weight** (`bool`): whether to use the output weights or input weights.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
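
As a toy illustration of the contribution utility (a minimal sketch with hypothetical numbers, independent of this class):

```python
import torch

weight = torch.tensor([[1.0, -1.0],
                       [2.0, 0.0],
                       [0.5, 0.5]])                  # 3 units x 2 input weights
activation = torch.tensor([[1.0, -2.0, 0.0],
                           [3.0, 2.0, 4.0]])         # batch of 2 samples, 3 units

weight_abs_sum = weight.abs().sum(dim=1)             # per unit: [2.0, 2.0, 1.0]
activation_abs_mean = activation.abs().mean(dim=0)   # per unit: [2.0, 2.0, 2.0]
contribution = weight_abs_sum * activation_abs_mean  # per unit: [4.0, 4.0, 2.0]
```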
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            weight_abs = torch.abs(layer.weight.data)
            weight_abs_sum = torch.sum(
                weight_abs,
                dim=[
                    i for i in range(weight_abs.dim()) if i != 0
                ],  # sum over the input dimension
            )
        else:
            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
            weight_abs_sum = torch.sum(
                weight_abs,
                dim=[
                    i for i in range(weight_abs.dim()) if i != 1
                ],  # sum over the output dimension
            )

        activation_abs_batch_mean = torch.mean(
            torch.abs(activation),
            dim=[
                i for i in range(activation.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = (weight_abs_sum * activation_abs_batch_mean).detach()

        return importance_step_layer

    def get_importance_step_layer_gradient_x_activation_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of the layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        input = input.requires_grad_()

        # initialize the Layer Gradient X Activation object
        layer_gradient_x_activation = LayerGradientXActivation(
            forward_func=self.forward, layer=layer
        )

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_gradient_x_activation.attribute(
            inputs=input,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean.detach()

        return importance_step_layer

    def get_importance_step_layer_weight_gradient_square_sum(
        self,
        layer_name: str,
        activation: Tensor,
        if_output_weight: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of the layer weight gradient squares. The weight gradient square is equal to the Fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **activation** (`Tensor`): the activation tensor of the layer, of size (batch size, number of units).
        - **if_output_weight** (`bool`): whether to use the output weights or input weights.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            gradient_square = layer.weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 0
                ],  # sum over the input dimension
            )
        else:
            gradient_square = self.next_layer(layer_name).weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 1
                ],  # sum over the output dimension
            )

        importance_step_layer = gradient_square_sum.detach()

        return importance_step_layer

    def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
        self,
        layer_name: str,
        activation: Tensor,
        if_output_weight: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of the layer weight gradient squares multiplied by the absolute values of the activation. The weight gradient square is equal to the Fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **activation** (`Tensor`): the activation tensor of the layer, of size (batch size, number of units).
        - **if_output_weight** (`bool`): whether to use the output weights or input weights.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
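
The combination can be sketched on toy numbers (a minimal standalone sketch with hypothetical gradients, not this class's code path):

```python
import torch

grad = torch.tensor([[0.1, -0.3],
                     [0.2, 0.0]])        # 2 units x 2 input-weight gradients
activation = torch.tensor([[1.0, 2.0],
                           [3.0, 2.0]])  # batch of 2 samples, 2 units

# Fisher-style score: sum of squared gradients per unit
fisher_like = (grad ** 2).sum(dim=1)
# weight the score by the mean absolute activation of each unit
importance = fisher_like * activation.abs().mean(dim=0)
```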
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            gradient_square = layer.weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 0
                ],  # sum over the input dimension
            )
        else:
            gradient_square = self.next_layer(layer_name).weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 1
                ],  # sum over the output dimension
            )

        activation_abs_batch_mean = torch.mean(
            torch.abs(activation),
            dim=[
                i for i in range(activation.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = (
            gradient_square_sum * activation_abs_batch_mean
        ).detach()

        return importance_step_layer

    def get_importance_step_layer_conductance_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer Conductance object
        layer_conductance = LayerConductance(forward_func=self.forward, layer=layer)

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_conductance.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean.detach()

        return importance_step_layer

    def get_importance_step_layer_internal_influence_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
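
Conceptually, internal influence averages the gradient along a straight path from the baseline, without multiplying by the input difference. A minimal feature-level sketch (hypothetical helper on a toy linear function, not this class's Captum-based code path):

```python
import torch

def internal_influence_sketch(f, x, baseline, n_steps=20):
    # Riemann-sum approximation of the influence integral:
    # average the gradient of f along the path from `baseline` to `x`.
    total = torch.zeros_like(x)
    for k in range(1, n_steps + 1):
        point = (baseline + (k / n_steps) * (x - baseline)).detach().requires_grad_()
        (grad,) = torch.autograd.grad(f(point), point)
        total += grad
    return total / n_steps

w = torch.tensor([2.0, 3.0])
influence = internal_influence_sketch(
    lambda z: (w * z).sum(), torch.tensor([1.0, 4.0]), torch.zeros(2)
)
# for a linear function the gradient is constant, so influence equals w
```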
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Internal Influence object
        internal_influence = InternalInfluence(forward_func=self.forward, layer=layer)

        # convert the target to long type to avoid errors
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = internal_influence.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
            n_steps=5,  # set 5 instead of the default 50 to accelerate the computation
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean.detach()

        return importance_step_layer

    def get_importance_step_layer_gradcam_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://arxiv.org/abs/1610.02391). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer Grad-CAM object
        gradcam = LayerGradCam(forward_func=self.forward, layer=layer)

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = gradcam.attribute(
            inputs=input,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean.detach()

        return importance_step_layer

    def get_importance_step_layer_deeplift_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLIFT](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define the reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer DeepLift object
        layer_deeplift = LayerDeepLift(model=self, layer=layer)

        # convert the target to long type to avoid errors
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_deeplift.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean.detach()

        return importance_step_layer

    def get_importance_step_layer_deepliftshap_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLIFT SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define the reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
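
For a purely linear model, DeepLIFT attributions reduce to $w \cdot (x - \text{baseline})$, and the SHAP variant averages them over a distribution of reference samples. A toy sketch with hypothetical numbers (not this class's Captum-based code path):

```python
import torch

w = torch.tensor([2.0, 3.0])            # weights of a toy linear model
x = torch.tensor([1.0, 4.0])            # input to explain
baselines = torch.tensor([[0.0, 0.0],
                          [1.0, 2.0]])  # two reference samples

# DeepLIFT attribution per baseline is w * (x - baseline);
# the SHAP variant averages it over the baseline distribution.
attribution = (w * (x - baselines)).mean(dim=0)
# equals w * (x - baselines.mean(0)) = [1.0, 9.0]
```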
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer DeepLiftShap object
        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)

        # convert the target to long type to avoid errors
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_deepliftshap.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean.detach()

        return importance_step_layer

    def get_importance_step_layer_gradientshap_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if baselines is None:
            baselines = torch.zeros_like(
                input
            )  # baselines are mandatory for the GradientShap API, so we explicitly set them to zero

        # initialize the Layer GradientShap object
        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)

        # convert the target to long type to avoid errors
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_gradientshap.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean.detach()

        return importance_step_layer

    def get_importance_step_layer_integrated_gradients_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
1359 """ 1360 layer = self.backbone.get_layer_by_name(layer_name) 1361 1362 # initialize the Layer Integrated Gradients object 1363 layer_integrated_gradients = LayerIntegratedGradients( 1364 forward_func=self.forward, layer=layer 1365 ) 1366 1367 self.set_forward_func_return_logits_only(True) 1368 # calculate layer attribution of the step 1369 attribution = layer_integrated_gradients.attribute( 1370 inputs=input, 1371 baselines=baselines, 1372 target=target, 1373 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1374 ) 1375 self.set_forward_func_return_logits_only(False) 1376 1377 attribution_abs_batch_mean = torch.mean( 1378 torch.abs(attribution), 1379 dim=[ 1380 i for i in range(attribution.dim()) if i != 1 1381 ], # average the features over batch samples 1382 ) 1383 1384 importance_step_layer = attribution_abs_batch_mean 1385 importance_step_layer = importance_step_layer.detach() 1386 1387 return importance_step_layer 1388 1389 def get_importance_step_layer_feature_ablation_abs( 1390 self: str, 1391 layer_name: str, 1392 input: Tensor | tuple[Tensor, ...], 1393 layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1394 target: Tensor | None, 1395 batch_idx: int, 1396 num_batches: int, 1397 if_captum: bool = False, 1398 ) -> Tensor: 1399 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum. 1400 1401 **Args:** 1402 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1403 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 
1404 - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details. 1405 - **target** (`Tensor` | `None`): the target batch of the training step. 1406 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1407 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1408 - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster. 1409 1410 **Returns:** 1411 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1412 """ 1413 layer = self.backbone.get_layer_by_name(layer_name) 1414 1415 if not if_captum: 1416 # 1. Baseline logits (take first element of forward output) 1417 baseline_out, _, _ = self.forward( 1418 input, "train", batch_idx, num_batches, self.task_id 1419 ) 1420 if target is not None: 1421 baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1) 1422 else: 1423 baseline_scores = baseline_out.sum(dim=1) 1424 1425 # 2. Capture layer’s output shape 1426 activs = {} 1427 handle = layer.register_forward_hook( 1428 lambda module, inp, out: activs.setdefault("output", out.detach()) 1429 ) 1430 _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id) 1431 handle.remove() 1432 layer_output = activs["output"] # shape (B, F, ...) 1433 1434 # 3. 
Build baseline tensor matching that shape 1435 if layer_baselines is None: 1436 baseline_tensor = torch.zeros_like(layer_output) 1437 elif isinstance(layer_baselines, (int, float)): 1438 baseline_tensor = torch.full_like(layer_output, layer_baselines) 1439 elif isinstance(layer_baselines, Tensor): 1440 if layer_baselines.shape == layer_output.shape: 1441 baseline_tensor = layer_baselines 1442 elif layer_baselines.shape == layer_output.shape[1:]: 1443 baseline_tensor = layer_baselines.unsqueeze(0).repeat( 1444 layer_output.size(0), *([1] * layer_baselines.ndim) 1445 ) 1446 else: 1447 raise ValueError(...) 1448 else: 1449 raise ValueError(...) 1450 1451 B, F = layer_output.size(0), layer_output.size(1) 1452 1453 # 4. Create a “mega-batch” replicating the input F times 1454 if isinstance(input, tuple): 1455 mega_inputs = tuple( 1456 t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:]) 1457 for t in input 1458 ) 1459 else: 1460 mega_inputs = ( 1461 input.unsqueeze(0) 1462 .repeat(F, *([1] * input.ndim)) 1463 .view(-1, *input.shape[1:]) 1464 ) 1465 1466 # 5. Equally replicate the baseline tensor 1467 mega_baseline = ( 1468 baseline_tensor.unsqueeze(0) 1469 .repeat(F, *([1] * baseline_tensor.ndim)) 1470 .view(-1, *baseline_tensor.shape[1:]) 1471 ) 1472 1473 # 6. Precompute vectorized indices 1474 device = layer_output.device 1475 positions = torch.arange(F * B, device=device) # [0,1,...,F*B-1] 1476 feat_idx = torch.arange(F, device=device).repeat_interleave( 1477 B 1478 ) # [0,0,...,1,1,...,F-1] 1479 1480 # 7. 
One hook to zero out each channel slice across the mega-batch 1481 def mega_ablate_hook(module, inp, out): 1482 out_mod = out.clone() 1483 # for each sample in mega-batch, zero its corresponding channel 1484 out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx] 1485 return out_mod 1486 1487 h = layer.register_forward_hook(mega_ablate_hook) 1488 out_all, _, _ = self.forward( 1489 mega_inputs, "train", batch_idx, num_batches, self.task_id 1490 ) 1491 h.remove() 1492 1493 # 8. Recover scores, reshape [F*B] → [F, B], diff & mean 1494 if target is not None: 1495 tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1) 1496 scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1) 1497 else: 1498 scores_all = out_all.sum(dim=1) 1499 1500 scores_all = scores_all.view(F, B) 1501 diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all) 1502 importance_step_layer = diffs.mean(dim=1).detach() # [F] 1503 1504 return importance_step_layer 1505 1506 else: 1507 # initialize the Layer Feature Ablation object 1508 layer_feature_ablation = LayerFeatureAblation( 1509 forward_func=self.forward, layer=layer 1510 ) 1511 1512 # calculate layer attribution of the step 1513 self.set_forward_func_return_logits_only(True) 1514 attribution = layer_feature_ablation.attribute( 1515 inputs=input, 1516 layer_baselines=layer_baselines, 1517 # target=target, # disable target to enable perturbations_per_eval 1518 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1519 perturbations_per_eval=128, # to accelerate the computation 1520 ) 1521 self.set_forward_func_return_logits_only(False) 1522 1523 attribution_abs_batch_mean = torch.mean( 1524 torch.abs(attribution), 1525 dim=[ 1526 i for i in range(attribution.dim()) if i != 1 1527 ], # average the features over batch samples 1528 ) 1529 1530 importance_step_layer = attribution_abs_batch_mean 1531 importance_step_layer = importance_step_layer.detach() 1532 1533 return importance_step_layer 1534 1535 def 
get_importance_step_layer_lrp_abs( 1536 self: str, 1537 layer_name: str, 1538 input: Tensor | tuple[Tensor, ...], 1539 target: Tensor | None, 1540 batch_idx: int, 1541 num_batches: int, 1542 ) -> Tensor: 1543 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum. 1544 1545 **Args:** 1546 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1547 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1548 - **target** (`Tensor` | `None`): the target batch of the training step. 1549 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1550 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1551 1552 **Returns:** 1553 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1554 """ 1555 layer = self.backbone.get_layer_by_name(layer_name) 1556 1557 # initialize the Layer LRP object 1558 layer_lrp = LayerLRP(model=self, layer=layer) 1559 1560 # set model to evaluation mode to prevent updating the model parameters 1561 self.eval() 1562 1563 self.set_forward_func_return_logits_only(True) 1564 # calculate layer attribution of the step 1565 attribution = layer_lrp.attribute( 1566 inputs=input, 1567 target=target, 1568 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1569 ) 1570 self.set_forward_func_return_logits_only(False) 1571 1572 attribution_abs_batch_mean = torch.mean( 1573 torch.abs(attribution), 1574 dim=[ 1575 i for i in range(attribution.dim()) if i != 1 1576 ], # average the features over batch samples 1577 ) 1578 1579 importance_step_layer = attribution_abs_batch_mean 1580 importance_step_layer = importance_step_layer.detach() 1581 1582 return importance_step_layer 1583 1584 def get_importance_step_layer_cbp_adaptive_contribution( 1585 self: str, 1586 layer_name: str, 1587 activation: Tensor, 1588 ) -> Tensor: 1589 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the reciprocal of sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7). 1590 1591 **Args:** 1592 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1593 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 1594 1595 **Returns:** 1596 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1597 """ 1598 layer = self.backbone.get_layer_by_name(layer_name) 1599 1600 input_weight_abs = torch.abs(layer.weight.data) 1601 input_weight_abs_sum = torch.sum( 1602 input_weight_abs, 1603 dim=[ 1604 i for i in range(input_weight_abs.dim()) if i != 0 1605 ], # sum over the input dimension 1606 ) 1607 input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum) 1608 1609 output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 1610 output_weight_abs_sum = torch.sum( 1611 output_weight_abs, 1612 dim=[ 1613 i for i in range(output_weight_abs.dim()) if i != 1 1614 ], # sum over the output dimension 1615 ) 1616 1617 activation_abs_batch_mean = torch.mean( 1618 torch.abs(activation), 1619 dim=[ 1620 i for i in range(activation.dim()) if i != 1 1621 ], # average the features over batch samples 1622 ) 1623 1624 importance_step_layer = ( 1625 output_weight_abs_sum 1626 * activation_abs_batch_mean 1627 * input_weight_abs_sum_reciprocal 1628 ) 1629 importance_step_layer = importance_step_layer.detach() 1630 1631 return importance_step_layer
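The fast (non-Captum) path of `get_importance_step_layer_feature_ablation_abs` boils down to: score the unablated batch, replace one feature at a time with its baseline, and average the absolute score change over the batch. A minimal pure-Python sketch of that idea on toy activations and output weights with a zero baseline (illustration only, not the class's actual API):

```python
# Toy stand-in for a hidden layer: two samples, three hidden features.
activations = [
    [0.5, -1.0, 2.0],  # sample 1
    [1.5, 0.0, 1.0],   # sample 2
]
# Output weights mapping the three hidden features to a scalar score.
out_weights = [2.0, -1.0, 0.5]


def score(feats):
    return sum(w * f for w, f in zip(out_weights, feats))


baseline_scores = [score(f) for f in activations]

num_features = len(out_weights)
importance = []
for j in range(num_features):
    diffs = []
    for feats, base in zip(activations, baseline_scores):
        ablated = list(feats)
        ablated[j] = 0.0  # replace feature j with the zero baseline
        diffs.append(abs(base - score(ablated)))
    importance.append(sum(diffs) / len(diffs))  # mean |delta score| over the batch

print(importance)  # [2.0, 0.5, 0.75]
```

The method above vectorizes exactly this loop: it replicates the batch once per feature (the "mega-batch") so that all ablations run in a single forward pass.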
FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm.
An architecture-based continual learning approach that improves AdaHAT (Adaptive Hard Attention to the Task) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components.
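The neuron-wise importance is modulated by a scheduling factor $c^t$ (the `importance_scheduler_type` options documented below). A minimal numeric sketch of the first two options, using the documented defaults $b_L = 10$ and $b_R = 0.1$ (hypothetical helper names; the real class computes these factors internally):

```python
# Hypothetical helpers mirroring the documented scheduler formulas.
def linear_sparsity_reg_scheduler(t, reg, base_linear=10.0, base_reg=0.1):
    """'linear_sparsity_reg': c^t = (t + b_L) * [R + b_R]."""
    return (t + base_linear) * (reg + base_reg)


def sparsity_reg_scheduler(reg, base_reg=0.1):
    """'sparsity_reg': c^t = R + b_R."""
    return reg + base_reg


# A mask sparsity regularization R(M^t, M^{<t}) = 0.4 at task t = 3 gives:
print(linear_sparsity_reg_scheduler(3, 0.4))  # (3 + 10) * 0.5 = 6.5
print(sparsity_reg_scheduler(0.4))
```

Both options grow with the overlap between the current and previous masks, so importance counts for more when capacity is being contested.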
def __init__(
        self,
        backbone: HATMaskBackbone,
        heads: HeadsTIL | HeadDIL,
        adjustment_intensity: float,
        importance_type: str,
        importance_summing_strategy: str,
        importance_scheduler_type: str,
        neuron_to_weight_importance_aggregation_mode: str,
        s_max: float,
        clamp_threshold: float,
        mask_sparsity_reg_factor: float,
        mask_sparsity_reg_mode: str = "original",
        base_importance: float = 0.01,
        base_mask_sparsity_reg: float = 0.1,
        base_linear: float = 10,
        filter_by_cumulative_mask: bool = False,
        filter_unmasked_importance: bool = True,
        step_multiply_training_mask: bool = True,
        task_embedding_init_mode: str = "N01",
        importance_summing_strategy_linear_step: float | None = None,
        importance_summing_strategy_exponential_rate: float | None = None,
        importance_summing_strategy_log_base: float | None = None,
        non_algorithmic_hparams: dict[str, Any] = {},
        **kwargs,
    ) -> None:
        r"""Initialize the FG-AdaHAT algorithm with the network.

        **Args:**
        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. FG-AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
        - **importance_type** (`str`): the type of neuron-wise importance, must be one of:
            1. 'input_weight_abs_sum': sum of absolute input weights;
            2. 'output_weight_abs_sum': sum of absolute output weights;
            3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
            4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
            5. 'activation_abs': absolute activation;
            6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
            7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
            8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
            9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
            10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
            11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
            12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
            13. 'conductance_abs': absolute layer conductance;
            14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
            15. 'gradcam_abs': absolute Grad-CAM;
            16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
            17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
            18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
            19. 'integrated_gradients_abs': absolute Integrated Gradients;
            20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
            21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
            22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
            23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7).
        - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of:
            1. 'add_latest': add the latest neuron-wise importance to the summative importance;
            2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
            3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
            4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
            5. 'quadratic_decrease': weigh the previous neuron-wise importance by a factor that decreases quadratically with the task ID;
            6. 'cubic_decrease': weigh the previous neuron-wise importance by a factor that decreases cubically with the task ID;
            7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
            8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
            9. 'factorial_decrease': weigh the previous neuron-wise importance by a factor that decreases factorially with the task ID.
        - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ multiplied with the parameter importance. Must be one of:
            1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`);
            2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
            3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
        - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
            1. 'min': take the minimum of neuron-wise importance for each weight;
            2. 'max': take the maximum of neuron-wise importance for each weight;
            3. 'mean': take the mean of neuron-wise importance for each weight.
        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a);
            2. 'cross': the cross version of mask sparsity regularization.
        - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01.
        - **base_mask_sparsity_reg** (`float`): base value added to the mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
        - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
        - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask with the importance when calculating the adjustment rate. Default: False.
        - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set them to 0) at the end of task training. Default: True.
        - **step_multiply_training_mask** (`bool`): whether to multiply the training mask with the importance at each training step. Default: True.
        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
            1. 'N01' (default): standard normal distribution $N(0, 1)$;
            2. 'U-11': uniform distribution $U(-1, 1)$;
            3. 'U01': uniform distribution $U(0, 1)$;
            4. 'U-10': uniform distribution $U(-1, 0)$;
            5. 'last': inherit the task embedding from the last task.
        - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0.
        - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1.
        - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs by the `save_hyperparameters()` method. This is useful for experiment configuration and reproducibility.
        - **kwargs**: reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            adjustment_mode=None,  # use FG-AdaHAT's own adjustment mechanism
            adjustment_intensity=adjustment_intensity,
            s_max=s_max,
            clamp_threshold=clamp_threshold,
            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
            task_embedding_init_mode=task_embedding_init_mode,
            epsilon=base_mask_sparsity_reg,  # the epsilon is now the base mask sparsity regularization factor
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "adjustment_intensity",
            "importance_type",
            "importance_summing_strategy",
            "importance_scheduler_type",
            "neuron_to_weight_importance_aggregation_mode",
            "s_max",
            "clamp_threshold",
            "mask_sparsity_reg_factor",
            "mask_sparsity_reg_mode",
            "base_importance",
            "base_mask_sparsity_reg",
            "base_linear",
            "filter_by_cumulative_mask",
            "filter_unmasked_importance",
            "step_multiply_training_mask",
        )

        self.importance_type: str | None = importance_type
        r"""The type of the neuron-wise importance added to AdaHAT importance."""

        self.importance_scheduler_type: str = importance_scheduler_type
        r"""The type of the importance scheduler."""
        self.neuron_to_weight_importance_aggregation_mode: str = (
            neuron_to_weight_importance_aggregation_mode
        )
        r"""The mode of aggregation from neuron-wise to weight-wise importance."""
        self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask
        r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate."""
        self.filter_unmasked_importance: bool = filter_unmasked_importance
        r"""The flag to filter unmasked importance values (set them to 0) at the end of task training."""
        self.step_multiply_training_mask: bool = step_multiply_training_mask
        r"""The flag to multiply the training mask with the importance at each training step."""

        # importance summing strategy
        self.importance_summing_strategy: str = importance_summing_strategy
        r"""The strategy to sum the neuron-wise importance for previous tasks."""
        if importance_summing_strategy_linear_step is not None:
            self.importance_summing_strategy_linear_step: float = (
                importance_summing_strategy_linear_step
            )
            r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease')."""
        if importance_summing_strategy_exponential_rate is not None:
            self.importance_summing_strategy_exponential_rate: float = (
                importance_summing_strategy_exponential_rate
            )
            r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease')."""
        if importance_summing_strategy_log_base is not None:
            self.importance_summing_strategy_log_base: float = (
                importance_summing_strategy_log_base
            )
            r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease')."""

        # base values
        self.base_importance: float = base_importance
        r"""The base value added to the importance to avoid zero."""
        self.base_mask_sparsity_reg: float = base_mask_sparsity_reg
        r"""The base value added to the mask sparsity regularization to avoid zero."""
        self.base_linear: float = base_linear
        r"""The base value added to the linear factor in the importance scheduler to avoid zero."""

        self.importances: dict[int, dict[str, Tensor]] = {}
        r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer, which has the same size as the feature tensor, (number of units, )."""
        self.summative_importance_for_previous_tasks: dict[str, Tensor] = {}
        r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer, which has the same size as the feature tensor, (number of units, )."""

        self.num_steps_t: int
        r"""The number of training steps for the current task `self.task_id`."""

        # set manual optimization
        self.automatic_optimization = False

        FGAdaHAT.sanity_check(self)
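The `neuron_to_weight_importance_aggregation_mode` options ('min', 'max', 'mean') turn neuron-wise importance into weight-wise importance. A pure-Python sketch of one plausible reading, where the weight connecting previous-layer neuron $j$ to current-layer neuron $i$ combines the two neurons' importance values (the `aggregate` helper is hypothetical, not the class's actual API):

```python
def aggregate(prev_importance, cur_importance, mode):
    """Aggregate neuron-wise importance into weight-wise importance.

    One plausible reading of Agg(.): weight (i, j) connects neuron j in the
    previous layer to neuron i in this layer, so its importance combines the
    two neurons' importance values with min, max, or mean.
    """
    ops = {"min": min, "max": max, "mean": lambda a, b: (a + b) / 2}
    op = ops[mode]
    return [
        [op(cur_importance[i], prev_importance[j]) for j in range(len(prev_importance))]
        for i in range(len(cur_importance))
    ]


prev = [0.2, 0.8]  # importance of the 2 neurons feeding the layer
cur = [0.5]        # importance of the single output neuron
print(aggregate(prev, cur, "min"))  # [[0.2, 0.5]]
print(aggregate(prev, cur, "max"))  # [[0.5, 0.8]]
print(aggregate(prev, cur, "mean"))
```

'min' is the most conservative choice: a weight is only protected as strongly as the less important of its two endpoint neurons.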
Initialize the FG-AdaHAT algorithm with the network.
Args:
- backbone (
HATMaskBackbone): must be a backbone network with the HAT mask mechanism. - heads (
HeadsTIL|HeadDIL): output heads. FG-AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning). - adjustment_intensity (
float): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper). - importance_type (
str): the type of neuron-wise importance, must be one of:- 'input_weight_abs_sum': sum of absolute input weights;
- 'output_weight_abs_sum': sum of absolute output weights;
- 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
- 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
- 'activation_abs': absolute activation;
- 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
- 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
- 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
- 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
- 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
- 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
- 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
- 'conductance_abs': absolute layer conductance;
- 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
- 'gradcam_abs': absolute Grad-CAM;
- 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
- 'deepliftshap_abs': absolute DeepLIFT-SHAP;
- 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
- 'integrated_gradients_abs': absolute Integrated Gradients;
- 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
- 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
- 'cbp_adaptation': the adaptation function in Continual Backpropagation (CBP);
- 'cbp_adaptive_contribution': the adaptive contribution function in Continual Backpropagation (CBP);
- importance_summing_strategy (str): the strategy to sum neuron-wise importance for previous tasks, must be one of:
    - 'add_latest': add the latest neuron-wise importance to the summative importance;
    - 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
    - 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
    - 'linear_decrease': weigh the previous neuron-wise importance by a factor that decreases linearly with the task ID;
    - 'quadratic_decrease': weigh the previous neuron-wise importance by a factor that decreases quadratically with the task ID;
    - 'cubic_decrease': weigh the previous neuron-wise importance by a factor that decreases cubically with the task ID;
    - 'exponential_decrease': weigh the previous neuron-wise importance by a factor that decreases exponentially with the task ID;
    - 'log_decrease': weigh the previous neuron-wise importance by a factor that decreases logarithmically with the task ID;
    - 'factorial_decrease': weigh the previous neuron-wise importance by a factor that decreases factorially with the task ID.
- importance_scheduler_type (str): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance, must be one of:
    - 'linear_sparsity_reg': $c^t = (t + b_L) \cdot \left[R\left(M^t, M^{<t}\right) + b_R\right]$, where $t$ is the task ID, $b_L$ is the base linear value (see argument base_linear), and $b_R$ is the base mask sparsity regularization factor (see argument base_mask_sparsity_reg);
    - 'sparsity_reg': $c^t = R\left(M^t, M^{<t}\right) + b_R$;
    - 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min\left(m^{<t,\text{sum}}_{l,i}, m^{<t,\text{sum}}_{l-1,j}\right) + b_L\right) \cdot \left[R\left(M^t, M^{<t}\right) + b_R\right]$, where $m^{<t,\text{sum}}$ is the summative mask for previous tasks.
- neuron_to_weight_importance_aggregation_mode (str): the mode of aggregation from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
    - 'min': take the minimum of neuron-wise importance for each weight;
    - 'max': take the maximum of neuron-wise importance for each weight;
    - 'mean': take the mean of neuron-wise importance for each weight.
- s_max (float): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
- clamp_threshold (float): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
- mask_sparsity_reg_factor (float): hyperparameter, the regularization factor for mask sparsity.
- mask_sparsity_reg_mode (str): the mode of mask sparsity regularization, must be one of:
    - 'original' (default): the original mask sparsity regularization in the HAT paper;
    - 'cross': the cross version of mask sparsity regularization.
- base_importance (float): base value added to importance ($b_I$ in the paper). Default: 0.01.
- base_mask_sparsity_reg (float): base value added to the mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
- base_linear (float): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
- filter_by_cumulative_mask (bool): whether to multiply the cumulative mask with the importance when calculating the adjustment rate. Default: False.
- filter_unmasked_importance (bool): whether to filter unmasked importance values (set them to 0) at the end of task training. Default: True.
- step_multiply_training_mask (bool): whether to multiply the training mask with the importance at each training step. Default: True.
- task_embedding_init_mode (str): the initialization mode for task embeddings, must be one of:
    - 'N01' (default): standard normal distribution $N(0, 1)$;
    - 'U-11': uniform distribution $U(-1, 1)$;
    - 'U01': uniform distribution $U(0, 1)$;
    - 'U-10': uniform distribution $U(-1, 0)$;
    - 'last': inherit the task embedding from the last task.
- importance_summing_strategy_linear_step (float | None): linear step for the importance summing strategy (used when importance_summing_strategy is 'linear_decrease'). Must be > 0.
- importance_summing_strategy_exponential_rate (float | None): exponential rate for the importance summing strategy (used when importance_summing_strategy is 'exponential_decrease'). Must be > 1.
- importance_summing_strategy_log_base (float | None): base of the logarithm in the importance summing strategy (used when importance_summing_strategy is 'log_decrease'). Must be > 1.
- non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters, such as optimizer and learning rate scheduler configurations, passed to this LightningModule object from the config. They are saved for Lightning APIs via the save_hyperparameters() method, which is useful for experiment configuration and reproducibility.
- kwargs: reserved for multiple inheritance.
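The aggregation mode turns per-unit importance into per-weight importance. Below is a pure-Python sketch, under the assumption that each weight $w_{ij}$ connecting input unit $j$ to output unit $i$ combines the importance values of its two endpoint units (in the package this is done by the backbone's get_layer_measure_parameter_wise on tensors):

```python
def aggregate_to_weight_importance(output_unit_importance, input_unit_importance, mode="min"):
    # one importance value per weight w_ij, combined from the importance of
    # output unit i and input unit j
    agg = {"min": min, "max": max, "mean": lambda a, b: (a + b) / 2}[mode]
    return [
        [agg(out_imp, in_imp) for in_imp in input_unit_importance]
        for out_imp in output_unit_importance
    ]

out_imp = [0.9, 0.1]      # importance of the layer's 2 output units
in_imp = [0.5, 0.8, 0.2]  # importance of its 3 input units
print(aggregate_to_weight_importance(out_imp, in_imp, "min"))
# [[0.5, 0.8, 0.2], [0.1, 0.1, 0.1]]
```

With 'min', a weight is treated as important only if both of its endpoint units are important, which is the most conservative choice when protecting previous tasks.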
The mode of aggregation from neuron-wise to weight-wise importance.
The flag to filter importance by the cumulative mask when calculating the adjustment rate.
The flag to filter unmasked importance values (set them to 0) at the end of task training.
The flag to multiply the training mask to the importance at each training step.
The base value added to the mask sparsity regularization to avoid zero.
The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer, which has the same size as the layer's feature tensor, i.e., (number of units, ).
The summative neuron-wise importance values of units for previous tasks before the current task self.task_id. It is $I^{<t}_{l}$ in the paper.
@property
def automatic_optimization(self) -> bool:
    """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
    return self._automatic_optimization
If set to False you are responsible for calling .backward(), .step(), .zero_grad().
def sanity_check(self) -> None:
    r"""Sanity check."""

    # check importance type
    if self.importance_type not in [
        "input_weight_abs_sum",
        "output_weight_abs_sum",
        "input_weight_gradient_abs_sum",
        "output_weight_gradient_abs_sum",
        "activation_abs",
        "input_weight_abs_sum_x_activation_abs",
        "output_weight_abs_sum_x_activation_abs",
        "gradient_x_activation_abs",
        "input_weight_gradient_square_sum",
        "output_weight_gradient_square_sum",
        "input_weight_gradient_square_sum_x_activation_abs",
        "output_weight_gradient_square_sum_x_activation_abs",
        "conductance_abs",
        "internal_influence_abs",
        "gradcam_abs",
        "deeplift_abs",
        "deepliftshap_abs",
        "gradientshap_abs",
        "integrated_gradients_abs",
        "feature_ablation_abs",
        "lrp_abs",
        "cbp_adaptation",
        "cbp_adaptive_contribution",
    ]:
        raise ValueError(
            f"importance_type must be one of the predefined types, but got {self.importance_type}"
        )

    # check importance summing strategy
    if self.importance_summing_strategy not in [
        "add_latest",
        "add_all",
        "add_average",
        "linear_decrease",
        "quadratic_decrease",
        "cubic_decrease",
        "exponential_decrease",
        "log_decrease",
        "factorial_decrease",
    ]:
        raise ValueError(
            f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}"
        )

    # check importance scheduler type
    if self.importance_scheduler_type not in [
        "linear_sparsity_reg",
        "sparsity_reg",
        "summative_mask_sparsity_reg",
    ]:
        raise ValueError(
            f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}"
        )

    # check neuron to weight importance aggregation mode
    if self.neuron_to_weight_importance_aggregation_mode not in [
        "min",
        "max",
        "mean",
    ]:
        raise ValueError(
            f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}"
        )

    # check base values
    if self.base_importance < 0:
        raise ValueError(
            f"base_importance must be >= 0, but got {self.base_importance}"
        )
    if self.base_mask_sparsity_reg <= 0:
        raise ValueError(
            f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}"
        )
    if self.base_linear <= 0:
        raise ValueError(f"base_linear must be > 0, but got {self.base_linear}")
Sanity check.
def on_train_start(self) -> None:
    r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization."""
    super().on_train_start()

    # initialize the importance for the current task
    self.importances[self.task_id] = {}

    # initialize the neuron importance at the beginning of each task. This should not
    # be called in `__init__()` because `self.device` is not available at that time.
    for layer_name in self.backbone.weighted_layer_names:
        layer = self.backbone.get_layer_by_name(
            layer_name
        )  # get the layer by its name
        num_units = layer.weight.shape[0]

        # initialize the accumulated importance at the beginning of each task
        self.importances[self.task_id][layer_name] = torch.zeros(num_units).to(
            self.device
        )

        # initialize the summative neuron-wise importance at the beginning of the first task
        if self.task_id == 1:
            # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$
            # is initialized as zeros when $t = 1$
            self.summative_importance_for_previous_tasks[layer_name] = torch.zeros(
                num_units
            ).to(self.device)

    # reset the number of steps counter for the current task
    self.num_steps_t = 0
Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.
def clip_grad_by_adjustment(
    self,
    network_sparsity: dict[str, Tensor],
) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
    r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper.

    Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

    Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

    **Args:**
    - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.

    **Returns:**
    - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
    - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
    - **capacity** (`Tensor`): the calculated network capacity.
    """

    # initialize network capacity metric
    capacity = HATNetworkCapacityMetric().to(self.device)
    adjustment_rate_weight = {}
    adjustment_rate_bias = {}

    # calculate the adjustment rate for gradients of the parameters, both weights
    # and biases (if they exist). See Eq. (2) in the paper
    for layer_name in self.backbone.weighted_layer_names:

        layer = self.backbone.get_layer_by_name(
            layer_name
        )  # get the layer by its name

        # placeholder for the adjustment rate to avoid using it before assignment
        adjustment_rate_weight_layer = 1
        adjustment_rate_bias_layer = 1

        # aggregate the neuron-wise importance to weight-wise importance. Note that
        # the neuron-wise importance has already been min-max scaled to $[0, 1]$ in
        # the `on_train_batch_end()` method, added the base value, and filtered by
        # the mask
        weight_importance, bias_importance = (
            self.backbone.get_layer_measure_parameter_wise(
                neuron_wise_measure=self.summative_importance_for_previous_tasks,
                layer_name=layer_name,
                aggregation_mode=self.neuron_to_weight_importance_aggregation_mode,
            )
        )

        weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
            neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
            layer_name=layer_name,
            aggregation_mode="min",
        )

        # filter the weight importance by the cumulative mask
        if self.filter_by_cumulative_mask:
            weight_importance = weight_importance * weight_mask
            bias_importance = bias_importance * bias_mask

        network_sparsity_layer = network_sparsity[layer_name]

        # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper
        factor = network_sparsity_layer + self.base_mask_sparsity_reg
        if self.importance_scheduler_type == "linear_sparsity_reg":
            factor = factor * (self.task_id + self.base_linear)
        elif self.importance_scheduler_type == "sparsity_reg":
            pass
        elif self.importance_scheduler_type == "summative_mask_sparsity_reg":
            factor = factor * (
                self.summative_mask_for_previous_tasks[layer_name] + self.base_linear
            )

        # calculate the adjustment rate
        adjustment_rate_weight_layer = torch.div(
            self.adjustment_intensity,
            (factor * weight_importance + self.adjustment_intensity),
        )
        adjustment_rate_bias_layer = torch.div(
            self.adjustment_intensity,
            (factor * bias_importance + self.adjustment_intensity),
        )

        # apply the adjustment rate to the gradients
        layer.weight.grad.data *= adjustment_rate_weight_layer
        if layer.bias is not None:
            layer.bias.grad.data *= adjustment_rate_bias_layer

        # store the adjustment rate for logging
        adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
        if layer.bias is not None:
            adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer

        # update network capacity metric
        capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)

    return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
Clip the gradients by the adjustment rate. See Eq. (1) in the paper.
Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.
Args:
- network_sparsity (dict[str, Tensor]): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.

Returns:
- adjustment_rate_weight (dict[str, Tensor]): the adjustment rate for weights. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
- adjustment_rate_bias (dict[str, Tensor]): the adjustment rate for biases. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
- capacity (Tensor): the calculated network capacity.
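Putting the pieces together, here is a scalar sketch (hypothetical helpers, not the package API) of how the importance scheduler $c^t$, the adjustment rate, and the network capacity interact, assuming the 'linear_sparsity_reg' scheduler and the division used in the source:

```python
def importance_scheduler(task_id: int, mask_sparsity_reg: float,
                         base_linear: float = 10.0,
                         base_mask_sparsity_reg: float = 0.1) -> float:
    # 'linear_sparsity_reg' variant: c^t = (t + b_L) * (R + b_R);
    # the 'sparsity_reg' variant drops the (t + b_L) factor
    return (task_id + base_linear) * (mask_sparsity_reg + base_mask_sparsity_reg)

def adjustment_rate(alpha: float, factor: float, importance: float) -> float:
    # alpha / (c^t * I + alpha): unimportant parameters (I ~ 0) keep their
    # gradient, important ones are strongly clipped
    return alpha / (factor * importance + alpha)

factor = importance_scheduler(task_id=5, mask_sparsity_reg=0.4)  # (5 + 10) * 0.5 = 7.5
rates = [adjustment_rate(1e-3, factor, imp) for imp in (0.0, 0.5, 1.0)]

# network capacity: the average adjustment rate over all parameters
capacity = sum(rates) / len(rates)

print(rates[0])         # 1.0: zero importance -> gradient untouched
print(rates[2] < 0.01)  # True: high importance -> gradient nearly frozen
```

The real method computes these quantities element-wise on tensors and accumulates capacity with HATNetworkCapacityMetric; the element-wise 'summative_mask_sparsity_reg' variant is omitted here.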
def on_train_batch_end(
    self, outputs: dict[str, Any], batch: Any, batch_idx: int
) -> None:
    r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.

    **Args:**
    - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`).
    - **batch** (`Any`): training data batch.
    - **batch_idx** (`int`): index of the current batch (for mask figure file name).
    """

    # get potentially useful information from the training batch
    activations = outputs["activations"]
    input = outputs["input"]
    target = outputs["target"]
    mask = outputs["mask"]
    num_batches = self.trainer.num_training_batches

    for layer_name in self.backbone.weighted_layer_names:
        # layer-wise operation

        activation = activations[layer_name]

        # calculate neuron-wise importance of the training step.
        # See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper
        if self.importance_type == "input_weight_abs_sum":
            importance_step = self.get_importance_step_layer_weight_abs_sum(
                layer_name=layer_name,
                if_output_weight=False,
                reciprocal=False,
            )
        elif self.importance_type == "output_weight_abs_sum":
            importance_step = self.get_importance_step_layer_weight_abs_sum(
                layer_name=layer_name,
                if_output_weight=True,
                reciprocal=False,
            )
        elif self.importance_type == "input_weight_gradient_abs_sum":
            importance_step = (
                self.get_importance_step_layer_weight_gradient_abs_sum(
                    layer_name=layer_name, if_output_weight=False
                )
            )
        elif self.importance_type == "output_weight_gradient_abs_sum":
            importance_step = (
                self.get_importance_step_layer_weight_gradient_abs_sum(
                    layer_name=layer_name, if_output_weight=True
                )
            )
        elif self.importance_type == "activation_abs":
            importance_step = self.get_importance_step_layer_activation_abs(
                activation=activation
            )
        elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
            importance_step = (
                self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=False,
                )
            )
        elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
            importance_step = (
                self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=True,
                )
            )
        elif self.importance_type == "gradient_x_activation_abs":
            importance_step = (
                self.get_importance_step_layer_gradient_x_activation_abs(
                    layer_name=layer_name,
                    input=input,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            )
        elif self.importance_type == "input_weight_gradient_square_sum":
            importance_step = (
                self.get_importance_step_layer_weight_gradient_square_sum(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=False,
                )
            )
        elif self.importance_type == "output_weight_gradient_square_sum":
            importance_step = (
                self.get_importance_step_layer_weight_gradient_square_sum(
                    layer_name=layer_name,
                    activation=activation,
                    if_output_weight=True,
                )
            )
        elif (
            self.importance_type
            == "input_weight_gradient_square_sum_x_activation_abs"
        ):
            importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
                layer_name=layer_name,
                activation=activation,
                if_output_weight=False,
            )
        elif (
            self.importance_type
            == "output_weight_gradient_square_sum_x_activation_abs"
        ):
            importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
                layer_name=layer_name,
                activation=activation,
                if_output_weight=True,
            )
        elif self.importance_type == "conductance_abs":
            importance_step = self.get_importance_step_layer_conductance_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "internal_influence_abs":
            importance_step = self.get_importance_step_layer_internal_influence_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "gradcam_abs":
            importance_step = self.get_importance_step_layer_gradcam_abs(
                layer_name=layer_name,
                input=input,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "deeplift_abs":
            importance_step = self.get_importance_step_layer_deeplift_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "deepliftshap_abs":
            importance_step = self.get_importance_step_layer_deepliftshap_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "gradientshap_abs":
            importance_step = self.get_importance_step_layer_gradientshap_abs(
                layer_name=layer_name,
                input=input,
                baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "integrated_gradients_abs":
            importance_step = (
                self.get_importance_step_layer_integrated_gradients_abs(
                    layer_name=layer_name,
                    input=input,
                    baselines=None,
                    target=target,
                    batch_idx=batch_idx,
                    num_batches=num_batches,
                )
            )
        elif self.importance_type == "feature_ablation_abs":
            importance_step = self.get_importance_step_layer_feature_ablation_abs(
                layer_name=layer_name,
                input=input,
                layer_baselines=None,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "lrp_abs":
            importance_step = self.get_importance_step_layer_lrp_abs(
                layer_name=layer_name,
                input=input,
                target=target,
                batch_idx=batch_idx,
                num_batches=num_batches,
            )
        elif self.importance_type == "cbp_adaptation":
            importance_step = self.get_importance_step_layer_weight_abs_sum(
                layer_name=layer_name,
                if_output_weight=False,
                reciprocal=True,
            )
        elif self.importance_type == "cbp_adaptive_contribution":
            importance_step = (
                self.get_importance_step_layer_cbp_adaptive_contribution(
                    layer_name=layer_name,
                    activation=activation,
                )
            )

        # min-max scale the importance to $[0, 1]$. See Eq. (5) in the paper
        importance_step = min_max_normalize(importance_step)

        # multiply the importance by the training mask. See Eq. (6) in the paper
        if self.step_multiply_training_mask:
            importance_step = importance_step * mask[layer_name]

        # update accumulated importance
        self.importances[self.task_id][layer_name] = (
            self.importances[self.task_id][layer_name] + importance_step
        )

    # update number of steps counter
    self.num_steps_t += 1
Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.
Args:
- outputs (dict[str, Any]): outputs of the training step (returns of training_step() in CLAlgorithm).
- batch (Any): training data batch.
- batch_idx (int): index of the current batch (for mask figure file name).
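Each step's raw importance is min-max scaled to $[0, 1]$ (Eq. (5) in the paper) before masking and accumulation. A pure-Python sketch of that scaling (the real helper is clarena.utils.transforms.min_max_normalize and operates on tensors; the behavior for a constant input shown here is an assumption):

```python
def min_max_normalize(values):
    # rescale a step's raw importance values to [0, 1]
    lo, hi = min(values), max(values)
    if hi == lo:  # constant importance: nothing to distinguish (assumed behavior)
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]

print(min_max_normalize([2.0, 4.0, 10.0]))  # [0.0, 0.25, 1.0]
```

The scaling makes importance comparable across layers and steps before the base value $b_I$ is added at the end of the task.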
def on_train_end(self) -> None:
    r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task."""
    super().on_train_end()  # store the mask and update cumulative and summative masks

    for layer_name in self.backbone.weighted_layer_names:

        # average the neuron-wise step importance. See Eq. (4) in the paper
        self.importances[self.task_id][layer_name] = (
            self.importances[self.task_id][layer_name] / self.num_steps_t
        )

        # add the base importance. See Eq. (6) in the paper
        self.importances[self.task_id][layer_name] = (
            self.importances[self.task_id][layer_name] + self.base_importance
        )

        # filter unmasked importance
        if self.filter_unmasked_importance:
            self.importances[self.task_id][layer_name] = (
                self.importances[self.task_id][layer_name]
                * self.backbone.masks[f"{self.task_id}"][layer_name]
            )

        # calculate the summative neuron-wise importance for previous tasks.
        # See Eq. (4) in the paper
        if self.importance_summing_strategy == "add_latest":
            self.summative_importance_for_previous_tasks[
                layer_name
            ] += self.importances[self.task_id][layer_name]
        elif self.importance_summing_strategy == "add_all":
            for t in range(1, self.task_id + 1):
                self.summative_importance_for_previous_tasks[
                    layer_name
                ] += self.importances[t][layer_name]
        elif self.importance_summing_strategy == "add_average":
            for t in range(1, self.task_id + 1):
                self.summative_importance_for_previous_tasks[layer_name] += (
                    self.importances[t][layer_name] / self.task_id
                )
        else:
            # the decrease strategies recompute the weighted sum from scratch
            self.summative_importance_for_previous_tasks[
                layer_name
            ] = torch.zeros_like(
                self.summative_importance_for_previous_tasks[layer_name]
            ).to(
                self.device
            )  # start adding from 0

            for t in range(1, self.task_id + 1):
                if self.importance_summing_strategy == "linear_decrease":
                    s = self.importance_summing_strategy_linear_step
                    w_t = s * (self.task_id - t) + 1
                elif self.importance_summing_strategy == "quadratic_decrease":
                    w_t = (self.task_id - t + 1) ** 2
                elif self.importance_summing_strategy == "cubic_decrease":
                    w_t = (self.task_id - t + 1) ** 3
                elif self.importance_summing_strategy == "exponential_decrease":
                    r = self.importance_summing_strategy_exponential_rate
                    w_t = r ** (self.task_id - t + 1)
                elif self.importance_summing_strategy == "log_decrease":
                    a = self.importance_summing_strategy_log_base
                    w_t = math.log(self.task_id - t + 1, a) + 1
                elif self.importance_summing_strategy == "factorial_decrease":
                    w_t = math.factorial(self.task_id - t + 1)
                else:
                    raise ValueError(
                        f"Unknown importance_summing_strategy: {self.importance_summing_strategy}"
                    )

                # weigh each task's importance and add it to the summative importance
                self.summative_importance_for_previous_tasks[layer_name] += (
                    self.importances[t][layer_name] * w_t
                )
Additionally calculate neuron-wise importance for previous tasks at the end of training each task.
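The decrease strategies weigh each previous task's importance by a factor $w_t$ that grows with the task's age, so older tasks dominate the summative importance. A hypothetical helper sketching these weights (the exact offsets, e.g. in the logarithmic variant, are assumptions based on the source above):

```python
import math

def task_weight(strategy: str, t: int, task_id: int,
                step: float = 1.0, rate: float = 2.0, base: float = 2.0) -> float:
    """Weight w_t applied to task t's importance when summing up to task_id.

    `step`, `rate` and `base` stand in for the linear step, exponential rate
    and log base hyperparameters of the summing strategies.
    """
    age = task_id - t  # 0 for the latest task, larger for older tasks
    if strategy == "linear_decrease":
        return step * age + 1
    if strategy == "quadratic_decrease":
        return (age + 1) ** 2
    if strategy == "cubic_decrease":
        return (age + 1) ** 3
    if strategy == "exponential_decrease":
        return rate ** (age + 1)
    if strategy == "log_decrease":
        return math.log(age + 1, base) + 1
    if strategy == "factorial_decrease":
        return math.factorial(age + 1)
    raise ValueError(strategy)

# at task 3, the importance of task 1 (age 2) outweighs that of task 3 (age 0)
print(task_weight("quadratic_decrease", t=1, task_id=3))  # 9
print(task_weight("quadratic_decrease", t=3, task_id=3))  # 1
```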
def get_importance_step_layer_weight_abs_sum(
    self,
    layer_name: str,
    if_output_weight: bool,
    reciprocal: bool,
) -> Tensor:
    r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.

    **Args:**
    - **layer_name** (`str`): the name of the layer to get neuron-wise importance.
    - **if_output_weight** (`bool`): whether to use the output weights or input weights.
    - **reciprocal** (`bool`): whether to take the reciprocal.

    **Returns:**
    - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
    """
    layer = self.backbone.get_layer_by_name(layer_name)

    if not if_output_weight:
        weight_abs = torch.abs(layer.weight.data)
        weight_abs_sum = torch.sum(
            weight_abs,
            dim=[
                i for i in range(weight_abs.dim()) if i != 0
            ],  # sum over the input dimensions
        )
    else:
        weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
        weight_abs_sum = torch.sum(
            weight_abs,
            dim=[
                i for i in range(weight_abs.dim()) if i != 1
            ],  # sum over the output dimensions
        )

    if reciprocal:
        importance_step_layer = torch.reciprocal(weight_abs_sum)
    else:
        importance_step_layer = weight_abs_sum
    importance_step_layer = importance_step_layer.detach()

    return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.
Args:
- layer_name (str): the name of the layer to get neuron-wise importance.
- if_output_weight (bool): whether to use the output weights or input weights.
- reciprocal (bool): whether to take the reciprocal.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
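For a 2-D weight matrix, the input-weight version reduces to summing each row's absolute values; a pure-Python sketch (the real method operates on a weight `Tensor` and sums over all dimensions except the unit dimension):

```python
def input_weight_abs_sum(weight):
    # neuron-wise importance: for each output unit (row), sum the absolute
    # values of its incoming weights
    return [sum(abs(w) for w in row) for row in weight]

weight = [[0.5, -1.5], [-2.0, 1.0]]  # 2 units, 2 inputs each
print(input_weight_abs_sum(weight))  # [2.0, 3.0]
```

The 'cbp_adaptation' importance type is the reciprocal of this quantity (see the `reciprocal` argument).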
def get_importance_step_layer_weight_gradient_abs_sum(
    self,
    layer_name: str,
    if_output_weight: bool,
) -> Tensor:
    r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.

    **Args:**
    - **layer_name** (`str`): the name of the layer to get neuron-wise importance.
    - **if_output_weight** (`bool`): whether to use the output weights or input weights.

    **Returns:**
    - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
    """
    layer = self.backbone.get_layer_by_name(layer_name)

    if not if_output_weight:
        gradient_abs = torch.abs(layer.weight.grad.data)
        gradient_abs_sum = torch.sum(
            gradient_abs,
            dim=[
                i for i in range(gradient_abs.dim()) if i != 0
            ],  # sum over the input dimensions
        )
    else:
        gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data)
        gradient_abs_sum = torch.sum(
            gradient_abs,
            dim=[
                i for i in range(gradient_abs.dim()) if i != 1
            ],  # sum over the output dimensions
        )

    importance_step_layer = gradient_abs_sum
    importance_step_layer = importance_step_layer.detach()

    return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.
Args:
- layer_name (str): the name of the layer to get neuron-wise importance.
- if_output_weight (bool): whether to use the output weights or input weights.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_activation_abs(
    self,
    activation: Tensor,
) -> Tensor:
    r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum.

    **Args:**
    - **activation** (`Tensor`): the activation tensor of the layer, of size (number of units, ).

    **Returns:**
    - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
    """
    activation_abs_batch_mean = torch.mean(
        torch.abs(activation),
        dim=[
            i for i in range(activation.dim()) if i != 1
        ],  # average the features over batch samples
    )
    importance_step_layer = activation_abs_batch_mean
    importance_step_layer = importance_step_layer.detach()

    return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of Layer Activation in Captum.
Args:
- activation (Tensor): the activation tensor of the layer, of size (number of units, ).

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_weight_abs_sum_x_activation_abs(
    self,
    layer_name: str,
    activation: Tensor,
    if_output_weight: bool,
) -> Tensor:
    r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).

    **Args:**
    - **layer_name** (`str`): the name of the layer to get neuron-wise importance.
    - **activation** (`Tensor`): the activation tensor of the layer, of size (number of units, ).
    - **if_output_weight** (`bool`): whether to use the output weights or input weights.

    **Returns:**
    - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
    """
    layer = self.backbone.get_layer_by_name(layer_name)

    if not if_output_weight:
        weight_abs = torch.abs(layer.weight.data)
        weight_abs_sum = torch.sum(
            weight_abs,
            dim=[
                i for i in range(weight_abs.dim()) if i != 0
            ],  # sum over the input dimensions
        )
    else:
        weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
        weight_abs_sum = torch.sum(
            weight_abs,
            dim=[
                i for i in range(weight_abs.dim()) if i != 1
            ],  # sum over the output dimensions
        )

    activation_abs_batch_mean = torch.mean(
        torch.abs(activation),
        dim=[
            i for i in range(activation.dim()) if i != 1
        ],  # average the features over batch samples
    )

    importance_step_layer = weight_abs_sum * activation_abs_batch_mean
    importance_step_layer = importance_step_layer.detach()

    return importance_step_layer
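A framework-free sketch of the input-weight branch above (toy numbers only, not the clarena API): for each unit $j$, importance is the sum of absolute input weights into $j$ times the batch-mean absolute activation of $j$ — the contribution utility.

```python
# Hedged toy of the input-weight branch: importance[j] =
# (sum of |input weights into unit j|) * (batch-mean of |activation of unit j|).
def contribution_utility(weight, activations):
    """weight: [units][inputs] matrix; activations: [batch][units]."""
    num_units = len(weight)
    weight_abs_sum = [sum(abs(w) for w in row) for row in weight]  # sum over inputs
    batch = len(activations)
    act_abs_mean = [
        sum(abs(sample[j]) for sample in activations) / batch  # mean over batch
        for j in range(num_units)
    ]
    return [weight_abs_sum[j] * act_abs_mean[j] for j in range(num_units)]

# Two units, two inputs, a batch of two samples.
importance = contribution_utility(
    weight=[[1.0, -2.0], [0.5, 0.5]],
    activations=[[1.0, -4.0], [3.0, 0.0]],
)
print(importance)  # unit 0: (1+2) * (1+3)/2 = 6.0; unit 1: (0.5+0.5) * (4+0)/2 = 2.0
```

The real method does the same reduction with `torch.sum` / `torch.mean` over the corresponding tensor dimensions.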
    def get_importance_step_layer_gradient_x_activation_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of the layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        input = input.requires_grad_()

        # initialize the Layer Gradient X Activation object
        layer_gradient_x_activation = LayerGradientXActivation(
            forward_func=self.forward, layer=layer
        )

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_gradient_x_activation.attribute(
            inputs=input,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
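A hedged scalar sketch of what gradient-x-activation computes (Captum does this on real layers; the finite-difference gradient here is only an illustration): attribution is $(\partial \text{output} / \partial a_j) \cdot a_j$, then the batch-mean of absolute values, matching `attribution_abs_batch_mean` above.

```python
# Toy gradient-x-activation: attr[j] = grad_j * a[j]; importance = batch-mean |attr|.
def gradient_x_activation(readout, activations, eps=1e-6):
    """readout: scalar function of one activation vector; activations: [batch][units]."""
    num_units = len(activations[0])
    per_sample = []
    for a in activations:
        attr = []
        for j in range(num_units):
            bumped = list(a)
            bumped[j] += eps
            grad_j = (readout(bumped) - readout(a)) / eps  # finite-difference gradient
            attr.append(grad_j * a[j])
        per_sample.append(attr)
    batch = len(activations)
    return [
        sum(abs(per_sample[s][j]) for s in range(batch)) / batch
        for j in range(num_units)
    ]

# Linear readout 2*a0 - a1: gradients are (2, -1), so |attr| = (2|a0|, |a1|).
imp = gradient_x_activation(lambda a: 2 * a[0] - a[1], [[1.0, 3.0], [2.0, -1.0]])
print([round(x, 3) for x in imp])  # ≈ [3.0, 2.0]
```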
    def get_importance_step_layer_weight_gradient_square_sum(
        self,
        layer_name: str,
        activation: Tensor,
        if_output_weight: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of squared layer weight gradients. The squared weight gradient is equal to the Fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **activation** (`Tensor`): the activation tensor of the layer. It has size (number of units, ).
        - **if_output_weight** (`bool`): whether to use the output weights instead of the input weights.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            gradient_square = layer.weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 0
                ],  # sum over the input dimensions
            )
        else:
            gradient_square = self.next_layer(layer_name).weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 1
                ],  # sum over the output dimensions
            )

        importance_step_layer = gradient_square_sum
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
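A minimal framework-free sketch of the squared-gradient (diagonal-Fisher-style) reduction above, under the assumption of a plain `[units][inputs]` gradient matrix:

```python
# Hedged toy: importance[j] = sum over input weights into unit j of grad(w)**2,
# mirroring the input-weight branch of the method above.
def squared_grad_importance(weight_grads):
    """weight_grads: [units][inputs] gradient matrix of a layer's weights."""
    return [sum(g * g for g in row) for row in weight_grads]

print([round(v, 6) for v in squared_grad_importance([[0.1, -0.3], [2.0, 0.0]])])
# unit 0: 0.01 + 0.09 = 0.1; unit 1: 4.0 + 0.0 = 4.0
```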
    def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
        self,
        layer_name: str,
        activation: Tensor,
        if_output_weight: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of squared layer weight gradients multiplied by the absolute values of the activation. The squared weight gradient is equal to the Fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **activation** (`Tensor`): the activation tensor of the layer. It has size (number of units, ).
        - **if_output_weight** (`bool`): whether to use the output weights instead of the input weights.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            gradient_square = layer.weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 0
                ],  # sum over the input dimensions
            )
        else:
            gradient_square = self.next_layer(layer_name).weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 1
                ],  # sum over the output dimensions
            )

        activation_abs_batch_mean = torch.mean(
            torch.abs(activation),
            dim=[
                i for i in range(activation.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = gradient_square_sum * activation_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
    def get_importance_step_layer_conductance_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer Conductance object
        layer_conductance = LayerConductance(forward_func=self.forward, layer=layer)

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_conductance.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
    def get_importance_step_layer_internal_influence_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Internal Influence object
        internal_influence = InternalInfluence(forward_func=self.forward, layer=layer)

        # convert the target to long type to avoid error
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = internal_influence.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
            n_steps=5,  # set 5 instead of the default 50 to accelerate the computation
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
    def get_importance_step_layer_gradcam_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://arxiv.org/abs/1610.02391). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the GradCAM object
        gradcam = LayerGradCam(forward_func=self.forward, layer=layer)

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = gradcam.attribute(
            inputs=input,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
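A hedged one-sample toy of the Grad-CAM idea behind `LayerGradCam` (real inputs are 4-D feature maps; here "spatial positions" are a flat list): per-channel weights are the spatial mean of the gradients, and the map is the ReLU of the weighted sum of channel activations.

```python
# Toy Grad-CAM: weight_k = spatial mean of gradients of channel k;
# cam[p] = relu(sum_k weight_k * activation_k[p]).
def grad_cam(activations, gradients):
    """activations, gradients: [channels][positions] for one sample."""
    num_pos = len(activations[0])
    weights = [sum(g_row) / num_pos for g_row in gradients]  # spatial mean of grads
    cam = [
        sum(w * a_row[p] for w, a_row in zip(weights, activations))
        for p in range(num_pos)
    ]
    return [max(0.0, v) for v in cam]  # ReLU keeps positively contributing positions

cam = grad_cam(
    activations=[[1.0, 2.0], [0.5, 1.0]],
    gradients=[[1.0, 1.0], [-1.0, -1.0]],
)
print(cam)  # [1*1 - 1*0.5, 1*2 - 1*1] = [0.5, 1.0]
```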
    def get_importance_step_layer_deeplift_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define the reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer DeepLift object
        layer_deeplift = LayerDeepLift(model=self, layer=layer)

        # convert the target to long type to avoid error
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_deeplift.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
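A hedged one-neuron illustration of DeepLift's rescale rule, which is the core idea behind the `LayerDeepLift` call above (the real implementation chains multipliers through whole networks): the multiplier is the output difference over the input difference against a baseline, and the attribution is the multiplier times the input difference, so attributions sum exactly to the output difference.

```python
# Toy rescale rule for a single ReLU unit.
def relu(x):
    return max(0.0, x)

def deeplift_rescale(x, baseline):
    delta_in = x - baseline
    delta_out = relu(x) - relu(baseline)
    multiplier = delta_out / delta_in if delta_in != 0 else 0.0
    return multiplier * delta_in  # equals delta_out by construction (completeness)

print(deeplift_rescale(3.0, baseline=-1.0))  # relu(3) - relu(-1) = 3.0
```

Note the contrast with plain gradients: at `x = 3.0` the ReLU gradient times the input difference would be `1 * 4 = 4`, while DeepLift's attribution of `3.0` matches the actual output change.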
    def get_importance_step_layer_deepliftshap_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define the reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer DeepLiftShap object
        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)

        # convert the target to long type to avoid error
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_deepliftshap.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
    def get_importance_step_layer_gradientshap_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if baselines is None:
            baselines = torch.zeros_like(
                input
            )  # baselines are mandatory for the GradientShap API; we explicitly set them to zero

        # initialize the Layer GradientShap object
        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)

        # convert the target to long type to avoid error
        target = target.long() if target is not None else None

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_gradientshap.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
    def get_importance_step_layer_integrated_gradients_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer at a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer at the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer Integrated Gradients object
        layer_integrated_gradients = LayerIntegratedGradients(
            forward_func=self.forward, layer=layer
        )

        self.set_forward_func_return_logits_only(True)
        # calculate the layer attribution of the step
        attribution = layer_integrated_gradients.attribute(
            inputs=input,
            baselines=baselines,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of integrated gradients. We implement this using Layer Integrated Gradients in Captum.

**Args:**
- **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
- **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
- **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): the starting point from which the integral is computed. Please refer to the Captum documentation for more details.
- **target** (`Tensor` | `None`): the target batch of the training step.
- **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
- **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

**Returns:**
- **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
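The aggregation applied to the attribution here (the mean of absolute attributions over every dimension except the unit dimension, dim 1) can be sketched without Captum. The tensor shape below is an illustrative assumption, mimicking a conv layer's attribution:

```python
import torch

torch.manual_seed(0)

# hypothetical attribution of a conv layer for a batch of 8 samples:
# shape (batch, channels, height, width) = (8, 16, 4, 4)
attribution = torch.randn(8, 16, 4, 4)

# average the absolute attribution over every dimension except dim 1
# (the unit / channel dimension), yielding one importance score per neuron
importance = torch.mean(
    torch.abs(attribution),
    dim=[i for i in range(attribution.dim()) if i != 1],
)
assert importance.shape == (16,)  # one nonnegative score per channel
```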
```python
def get_importance_step_layer_feature_ablation_abs(
    self,
    layer_name: str,
    input: Tensor | tuple[Tensor, ...],
    layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
    target: Tensor | None,
    batch_idx: int,
    num_batches: int,
    if_captum: bool = False,
) -> Tensor:
    r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this either with a fast vectorized ablation of our own or with [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum (see `if_captum`).

    **Args:**
    - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
    - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
    - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details.
    - **target** (`Tensor` | `None`): the target batch of the training step.
    - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
    - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
    - **if_captum** (`bool`): whether to use Captum. If `True`, Captum is used to calculate the feature ablation; if `False`, our own implementation is used. Default is `False`, because our implementation is much faster.

    **Returns:**
    - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
    """
    layer = self.backbone.get_layer_by_name(layer_name)

    if not if_captum:
        # 1. baseline logits (take the first element of the forward output)
        baseline_out, _, _ = self.forward(
            input, "train", batch_idx, num_batches, self.task_id
        )
        if target is not None:
            baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1)
        else:
            baseline_scores = baseline_out.sum(dim=1)

        # 2. capture the layer's output shape
        activs = {}
        handle = layer.register_forward_hook(
            lambda module, inp, out: activs.setdefault("output", out.detach())
        )
        _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id)
        handle.remove()
        layer_output = activs["output"]  # shape (B, F, ...)

        # 3. build a baseline tensor matching that shape
        if layer_baselines is None:
            baseline_tensor = torch.zeros_like(layer_output)
        elif isinstance(layer_baselines, (int, float)):
            baseline_tensor = torch.full_like(layer_output, layer_baselines)
        elif isinstance(layer_baselines, Tensor):
            if layer_baselines.shape == layer_output.shape:
                baseline_tensor = layer_baselines
            elif layer_baselines.shape == layer_output.shape[1:]:
                baseline_tensor = layer_baselines.unsqueeze(0).repeat(
                    layer_output.size(0), *([1] * layer_baselines.ndim)
                )
            else:
                raise ValueError(
                    "The shape of `layer_baselines` must match the layer output "
                    "shape, with or without the batch dimension."
                )
        else:
            raise ValueError(
                "`layer_baselines` must be `None`, an `int`, a `float` or a `Tensor`."
            )

        B, F = layer_output.size(0), layer_output.size(1)

        # 4. create a "mega-batch" replicating the input F times
        if isinstance(input, tuple):
            mega_inputs = tuple(
                t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:])
                for t in input
            )
        else:
            mega_inputs = (
                input.unsqueeze(0)
                .repeat(F, *([1] * input.ndim))
                .view(-1, *input.shape[1:])
            )

        # 5. equally replicate the baseline tensor
        mega_baseline = (
            baseline_tensor.unsqueeze(0)
            .repeat(F, *([1] * baseline_tensor.ndim))
            .view(-1, *baseline_tensor.shape[1:])
        )

        # 6. precompute vectorized indices
        device = layer_output.device
        positions = torch.arange(F * B, device=device)  # [0, 1, ..., F*B-1]
        feat_idx = torch.arange(F, device=device).repeat_interleave(
            B
        )  # [0, ..., 0, 1, ..., 1, ..., F-1]

        # 7. one hook to ablate each channel slice across the mega-batch
        def mega_ablate_hook(module, inp, out):
            out_mod = out.clone()
            # for each sample in the mega-batch, ablate its corresponding channel
            out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx]
            return out_mod

        h = layer.register_forward_hook(mega_ablate_hook)
        out_all, _, _ = self.forward(
            mega_inputs, "train", batch_idx, num_batches, self.task_id
        )
        h.remove()

        # 8. recover scores, reshape [F*B] -> [F, B], then diff and mean
        if target is not None:
            tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1)
            scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1)
        else:
            scores_all = out_all.sum(dim=1)

        scores_all = scores_all.view(F, B)
        diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all)
        importance_step_layer = diffs.mean(dim=1).detach()  # [F]

        return importance_step_layer

    else:
        # initialize the Layer Feature Ablation object
        layer_feature_ablation = LayerFeatureAblation(
            forward_func=self.forward, layer=layer
        )

        # calculate the layer attribution of the step
        self.set_forward_func_return_logits_only(True)
        attribution = layer_feature_ablation.attribute(
            inputs=input,
            layer_baselines=layer_baselines,
            # target=target,  # disable target to enable perturbations_per_eval
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
            perturbations_per_eval=128,  # to accelerate the computation
        )
        self.set_forward_func_return_logits_only(False)

        # average the absolute attribution over all dimensions except the unit dimension
        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[i for i in range(attribution.dim()) if i != 1],
        )

        importance_step_layer = attribution_abs_batch_mean.detach()

        return importance_step_layer
```
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of feature ablation attribution. We implement this either with a fast vectorized ablation of our own or with Layer Feature Ablation in Captum (see `if_captum`).

**Args:**
- **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
- **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
- **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the Captum documentation for more details.
- **target** (`Tensor` | `None`): the target batch of the training step.
- **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
- **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
- **if_captum** (`bool`): whether to use Captum. If `True`, Captum is used to calculate the feature ablation; if `False`, our own implementation is used. Default is `False`, because our implementation is much faster.

**Returns:**
- **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
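The "mega-batch" trick in the non-Captum branch (ablating all F channels in a single forward pass) can be sketched on a toy network. The model, shapes, and zero baseline below are illustrative assumptions, and a naive per-channel loop cross-checks the result:

```python
import torch
from torch import nn

torch.manual_seed(0)

B, F_units = 4, 3  # batch size, units in the probed layer
model = nn.Sequential(nn.Linear(5, F_units), nn.ReLU(), nn.Linear(F_units, 2))
layer = model[1]  # probe the ReLU output
x = torch.randn(B, 5)

baseline_scores = model(x).sum(dim=1).detach()  # unablated scores, shape [B]

# replicate the batch F times: replica f ablates (zeroes) channel f
mega_x = x.unsqueeze(0).repeat(F_units, 1, 1).view(-1, 5)  # [F*B, 5]
positions = torch.arange(F_units * B)
feat_idx = torch.arange(F_units).repeat_interleave(B)

def ablate_hook(module, inp, out):
    out_mod = out.clone()
    out_mod[positions, feat_idx] = 0.0  # zero baseline, one channel per replica
    return out_mod

h = layer.register_forward_hook(ablate_hook)
scores_all = model(mega_x).sum(dim=1).view(F_units, B)  # [F, B]
h.remove()

# importance of each unit = mean absolute score change when it is ablated
importance = (baseline_scores.unsqueeze(0) - scores_all).abs().mean(dim=1).detach()
assert importance.shape == (F_units,)

# cross-check against naive one-channel-at-a-time ablation
naive = []
for f in range(F_units):
    def hook(m, i, o, f=f):
        o2 = o.clone()
        o2[:, f] = 0.0
        return o2
    hh = layer.register_forward_hook(hook)
    s = model(x).sum(dim=1)
    hh.remove()
    naive.append((baseline_scores - s).abs().mean().detach())
assert torch.allclose(importance, torch.stack(naive))
```

The single mega-batch forward replaces F separate forward passes, which is why this path is faster than Captum's generic per-feature perturbation loop.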
```python
def get_importance_step_layer_lrp_abs(
    self,
    layer_name: str,
    input: Tensor | tuple[Tensor, ...],
    target: Tensor | None,
    batch_idx: int,
    num_batches: int,
) -> Tensor:
    r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum.

    **Args:**
    - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
    - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
    - **target** (`Tensor` | `None`): the target batch of the training step.
    - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
    - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

    **Returns:**
    - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
    """
    layer = self.backbone.get_layer_by_name(layer_name)

    # initialize the Layer LRP object
    layer_lrp = LayerLRP(model=self, layer=layer)

    # set the model to evaluation mode to prevent updating the model parameters
    self.eval()

    self.set_forward_func_return_logits_only(True)
    # calculate the layer attribution of the step
    attribution = layer_lrp.attribute(
        inputs=input,
        target=target,
        additional_forward_args=("train", batch_idx, num_batches, self.task_id),
    )
    self.set_forward_func_return_logits_only(False)

    # average the absolute attribution over all dimensions except the unit dimension
    attribution_abs_batch_mean = torch.mean(
        torch.abs(attribution),
        dim=[i for i in range(attribution.dim()) if i != 1],
    )

    importance_step_layer = attribution_abs_batch_mean.detach()

    return importance_step_layer
```
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of LRP. We implement this using Layer LRP in Captum.

**Args:**
- **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
- **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
- **target** (`Tensor` | `None`): the target batch of the training step.
- **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
- **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

**Returns:**
- **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
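For intuition on what Layer LRP computes, here is a hand-rolled epsilon-rule LRP sketch on a toy linear net (the weights and shapes are illustrative assumptions, not the actual backbone); it also checks the layer-to-layer relevance conservation that LRP is built on:

```python
import torch

torch.manual_seed(0)

# tiny two-layer linear net: a -> z = W1 a -> out = W2 z (no biases for clarity)
a = torch.randn(4)
W1 = torch.randn(3, 4)
W2 = torch.randn(2, 3)
z = W1 @ a    # hidden units, shape [3]
out = W2 @ z  # logits, shape [2]

eps = 1e-6
R_out = torch.zeros(2)
R_out[0] = out[0]  # start from the relevance of one target logit

# epsilon rule: R_j = sum_k (z_j * w_kj) / (out_k + eps * sign(out_k)) * R_k
contrib = W2 * z                              # [2, 3] contributions z_j * w_kj
denom = (out + eps * out.sign()).unsqueeze(1)  # [2, 1] stabilized denominators
R_z = (contrib / denom * R_out.unsqueeze(1)).sum(dim=0)  # [3]

# relevance is (approximately) conserved from layer to layer
assert torch.allclose(R_z.sum(), R_out.sum(), atol=1e-4)

# the method then takes absolute values (and averages over the batch)
importance = R_z.abs()
```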
```python
def get_importance_step_layer_cbp_adaptive_contribution(
    self,
    layer_name: str,
    activation: Tensor,
) -> Tensor:
    r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the layer output weights, multiplied by the absolute values of the activation and divided by the sum of absolute values of the layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).

    **Args:**
    - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
    - **activation** (`Tensor`): the activation tensor of the layer, of size (batch size, number of units, ...).

    **Returns:**
    - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
    """
    layer = self.backbone.get_layer_by_name(layer_name)

    input_weight_abs = torch.abs(layer.weight.data)
    # sum over all dimensions except the unit (output) dimension of this layer
    input_weight_abs_sum = torch.sum(
        input_weight_abs,
        dim=[i for i in range(input_weight_abs.dim()) if i != 0],
    )
    input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum)

    output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
    # sum over all dimensions except the unit (input) dimension of the next layer
    output_weight_abs_sum = torch.sum(
        output_weight_abs,
        dim=[i for i in range(output_weight_abs.dim()) if i != 1],
    )

    # average the absolute activation over batch samples
    activation_abs_batch_mean = torch.mean(
        torch.abs(activation),
        dim=[i for i in range(activation.dim()) if i != 1],
    )

    importance_step_layer = (
        output_weight_abs_sum
        * activation_abs_batch_mean
        * input_weight_abs_sum_reciprocal
    )
    importance_step_layer = importance_step_layer.detach()

    return importance_step_layer
```
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the layer output weights, multiplied by the absolute values of the activation and divided by the sum of absolute values of the layer input weights. It is equal to the adaptive contribution utility in CBP.

**Args:**
- **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
- **activation** (`Tensor`): the activation tensor of the layer, of size (batch size, number of units, ...).

**Returns:**
- **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
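The adaptive contribution utility can be sketched on a toy pair of linear layers; the layer sizes and activations below are illustrative assumptions:

```python
import torch
from torch import nn

torch.manual_seed(0)

# toy setup: probe the 3 units sitting between fc1 and fc2
fc1 = nn.Linear(5, 3)           # input weights of the probed units, shape [3, 5]
fc2 = nn.Linear(3, 2)           # output weights of the probed units, shape [2, 3]
activation = torch.randn(8, 3)  # hypothetical unit activations for a batch of 8

in_weight_abs_sum = fc1.weight.data.abs().sum(dim=1)   # [3], per-unit input weight mass
out_weight_abs_sum = fc2.weight.data.abs().sum(dim=0)  # [3], per-unit output weight mass
act_abs_mean = activation.abs().mean(dim=0)            # [3], mean |activation| over the batch

# adaptive contribution utility: output-weight mass times mean |activation|,
# divided by input-weight mass
importance = out_weight_abs_sum * act_abs_mean / in_weight_abs_sum
assert importance.shape == (3,)
```

Dividing by the input-weight mass discounts units whose activation is large only because their incoming weights are large, which is the trade-off the CBP utility encodes.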
Inherited Members
- clarena.cl_algorithms.hat.HAT
- adjustment_mode
- s_max
- clamp_threshold
- mask_sparsity_reg_factor
- mask_sparsity_reg_mode
- mark_sparsity_reg
- task_embedding_init_mode
- alpha
- cumulative_mask_for_previous_tasks
- clip_grad_by_mask
- compensate_task_embedding_gradients
- forward
- training_step
- validation_step
- test_step