clarena.cl_algorithms.fgadahat
The submodule in cl_algorithms for the FG-AdaHAT algorithm.
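For orientation, here is a minimal construction sketch. Only the `FGAdaHAT` keyword arguments come from the `__init__` signature documented in the source below; the hyperparameter values are illustrative, and `HATMaskMLP` and the backbone/heads constructor arguments are assumptions for the example, not part of this module.

# Minimal sketch of constructing FG-AdaHAT (values illustrative; backbone/heads classes assumed).
from clarena.cl_algorithms import FGAdaHAT
from clarena.backbones import HATMaskMLP  # assumed concrete HATMaskBackbone subclass
from clarena.heads import HeadsTIL

backbone = HATMaskMLP(input_dim=784, hidden_dims=[256, 256], output_dim=64)  # assumed signature
heads = HeadsTIL(input_dim=64)  # assumed signature

model = FGAdaHAT(
    backbone=backbone,
    heads=heads,
    adjustment_intensity=0.1,  # the alpha in the paper
    importance_type="output_weight_abs_sum_x_activation_abs",  # "CU" in the paper
    importance_summing_strategy="add_latest",
    importance_scheduler_type="linear_sparsity_reg",
    neuron_to_weight_importance_aggregation_mode="min",
    s_max=400,
    clamp_threshold=50,
    mask_sparsity_reg_factor=0.75,
)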
1r""" 2The submodule in `cl_algorithms` for FG-AdaHAT algorithm. 3""" 4 5__all__ = ["FGAdaHAT"] 6 7import logging 8import math 9from typing import Any 10 11import torch 12from captum.attr import ( 13 InternalInfluence, 14 LayerConductance, 15 LayerDeepLift, 16 LayerDeepLiftShap, 17 LayerFeatureAblation, 18 LayerGradCam, 19 LayerGradientShap, 20 LayerGradientXActivation, 21 LayerIntegratedGradients, 22 LayerLRP, 23) 24from torch import Tensor 25 26from clarena.backbones import HATMaskBackbone 27from clarena.cl_algorithms import AdaHAT 28from clarena.heads import HeadsTIL 29from clarena.utils.metrics import HATNetworkCapacityMetric 30from clarena.utils.transforms import min_max_normalize 31 32# always get logger for built-in logging in each module 33pylogger = logging.getLogger(__name__) 34 35 36class FGAdaHAT(AdaHAT): 37 r"""FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm. 38 39 An architecture-based continual learning approach that improves [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT. 40 41 We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components. 42 """ 43 44 def __init__( 45 self, 46 backbone: HATMaskBackbone, 47 heads: HeadsTIL, 48 adjustment_intensity: float, 49 importance_type: str, 50 importance_summing_strategy: str, 51 importance_scheduler_type: str, 52 neuron_to_weight_importance_aggregation_mode: str, 53 s_max: float, 54 clamp_threshold: float, 55 mask_sparsity_reg_factor: float, 56 mask_sparsity_reg_mode: str = "original", 57 base_importance: float = 0.01, 58 base_mask_sparsity_reg: float = 0.1, 59 base_linear: float = 10, 60 filter_by_cumulative_mask: bool = False, 61 filter_unmasked_importance: bool = True, 62 step_multiply_training_mask: bool = True, 63 task_embedding_init_mode: str = "N01", 64 importance_summing_strategy_linear_step: float | None = None, 65 importance_summing_strategy_exponential_rate: float | None = None, 66 importance_summing_strategy_log_base: float | None = None, 67 non_algorithmic_hparams: dict[str, Any] = {}, 68 ) -> None: 69 r"""Initialize the FG-AdaHAT algorithm with the network. 70 71 **Args:** 72 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 73 - **heads** (`HeadsTIL`): output heads. FG-AdaHAT supports only TIL (Task-Incremental Learning). 74 - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper). 75 - **importance_type** (`str`): the type of neuron-wise importance, must be one of: 76 1. 'input_weight_abs_sum': sum of absolute input weights; 77 2. 'output_weight_abs_sum': sum of absolute output weights; 78 3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper); 79 4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper); 80 5. 'activation_abs': absolute activation; 81 6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper); 82 7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper); 83 8. 
'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation; 84 9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights; 85 10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights; 86 11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper); 87 12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation; 88 13. 'conductance_abs': absolute layer conductance; 89 14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper); 90 15. 'gradcam_abs': absolute Grad-CAM; 91 16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper); 92 17. 'deepliftshap_abs': absolute DeepLIFT-SHAP; 93 18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper); 94 19. 'integrated_gradients_abs': absolute Integrated Gradients; 95 20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper); 96 21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP); 97 22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 98 23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 99 - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of: 100 1. 'add_latest': add the latest neuron-wise importance to the summative importance; 101 2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance; 102 3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance; 103 4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID; 104 5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID; 105 6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID; 106 7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID; 107 8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID; 108 9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID; 109 - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of: 110 1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization betwwen the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`); 111 2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$; 112 3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$. 
113 - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of: 114 1. 'min': take the minimum of neuron-wise importance for each weight; 115 2. 'max': take the maximum of neuron-wise importance for each weight; 116 3. 'mean': take the mean of neuron-wise importance for each weight. 117 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 118 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 119 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity. 120 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 121 1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 122 2. 'cross': the cross version of mask sparsity regularization. 123 - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01. 124 - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1. 125 - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10. 126 - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False. 127 - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set to 0) at the end of task training. Default: False. 128 - **step_multiply_training_mask** (`bool`): whether to multiply the training mask to the importance at each training step. Default: True. 129 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 130 1. 'N01' (default): standard normal distribution $N(0, 1)$. 131 2. 'U-11': uniform distribution $U(-1, 1)$. 132 3. 'U01': uniform distribution $U(0, 1)$. 133 4. 'U-10': uniform distribution $U(-1, 0)$. 134 5. 'last': inherit the task embedding from the last task. 135 - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0. 136 - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1. 137 - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1. 138 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 
139 140 """ 141 super().__init__( 142 backbone=backbone, 143 heads=heads, 144 adjustment_mode=None, # use the own adjustment mechanism of FG-AdaHAT 145 adjustment_intensity=adjustment_intensity, 146 s_max=s_max, 147 clamp_threshold=clamp_threshold, 148 mask_sparsity_reg_factor=mask_sparsity_reg_factor, 149 mask_sparsity_reg_mode=mask_sparsity_reg_mode, 150 task_embedding_init_mode=task_embedding_init_mode, 151 epsilon=base_mask_sparsity_reg, # the epsilon is now the base mask sparsity regularization factor 152 non_algorithmic_hparams=non_algorithmic_hparams, 153 ) 154 155 # save additional algorithmic hyperparameters 156 self.save_hyperparameters( 157 "adjustment_intensity", 158 "importance_type", 159 "importance_summing_strategy", 160 "importance_scheduler_type", 161 "neuron_to_weight_importance_aggregation_mode", 162 "s_max", 163 "clamp_threshold", 164 "mask_sparsity_reg_factor", 165 "mask_sparsity_reg_mode", 166 "base_importance", 167 "base_mask_sparsity_reg", 168 "base_linear", 169 "filter_by_cumulative_mask", 170 "filter_unmasked_importance", 171 "step_multiply_training_mask", 172 ) 173 174 self.importance_type: str | None = importance_type 175 r"""The type of the neuron-wise importance added to AdaHAT importance.""" 176 177 self.importance_scheduler_type: str = importance_scheduler_type 178 r"""The type of the importance scheduler.""" 179 self.neuron_to_weight_importance_aggregation_mode: str = ( 180 neuron_to_weight_importance_aggregation_mode 181 ) 182 r"""The mode of aggregation from neuron-wise to weight-wise importance. """ 183 self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask 184 r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate.""" 185 self.filter_unmasked_importance: bool = filter_unmasked_importance 186 r"""The flag to filter unmasked importance values (set them to 0) at the end of task training.""" 187 self.step_multiply_training_mask: bool = step_multiply_training_mask 188 r"""The flag to multiply the training mask to the importance at each training step.""" 189 190 # importance summing strategy 191 self.importance_summing_strategy: str = importance_summing_strategy 192 r"""The strategy to sum the neuron-wise importance for previous tasks.""" 193 if importance_summing_strategy_linear_step is not None: 194 self.importance_summing_strategy_linear_step: float = ( 195 importance_summing_strategy_linear_step 196 ) 197 r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease').""" 198 if importance_summing_strategy_exponential_rate is not None: 199 self.importance_summing_strategy_exponential_rate: float = ( 200 importance_summing_strategy_exponential_rate 201 ) 202 r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """ 203 if importance_summing_strategy_log_base is not None: 204 self.importance_summing_strategy_log_base: float = ( 205 importance_summing_strategy_log_base 206 ) 207 r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """ 208 209 # base values 210 self.base_importance: float = base_importance 211 r"""The base value added to the importance to avoid zero. """ 212 self.base_mask_sparsity_reg: float = base_mask_sparsity_reg 213 r"""The base value added to the mask sparsity regularization to avoid zero. 
""" 214 self.base_linear: float = base_linear 215 r"""The base value added to the linear layer to avoid zero. """ 216 217 self.importances: dict[int, dict[str, Tensor]] = {} 218 r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """ 219 self.summative_importance_for_previous_tasks: dict[str, Tensor] = {} 220 r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """ 221 222 self.num_steps_t: int 223 r"""The number of training steps for the current task `self.task_id`.""" 224 # set manual optimization 225 self.automatic_optimization = False 226 227 FGAdaHAT.sanity_check(self) 228 229 def sanity_check(self) -> None: 230 r"""Sanity check.""" 231 232 # check importance type 233 if self.importance_type not in [ 234 "input_weight_abs_sum", 235 "output_weight_abs_sum", 236 "input_weight_gradient_abs_sum", 237 "output_weight_gradient_abs_sum", 238 "activation_abs", 239 "input_weight_abs_sum_x_activation_abs", 240 "output_weight_abs_sum_x_activation_abs", 241 "gradient_x_activation_abs", 242 "input_weight_gradient_square_sum", 243 "output_weight_gradient_square_sum", 244 "input_weight_gradient_square_sum_x_activation_abs", 245 "output_weight_gradient_square_sum_x_activation_abs", 246 "conductance_abs", 247 "internal_influence_abs", 248 "gradcam_abs", 249 "deeplift_abs", 250 "deepliftshap_abs", 251 "gradientshap_abs", 252 "integrated_gradients_abs", 253 "feature_ablation_abs", 254 "lrp_abs", 255 "cbp_adaptation", 256 "cbp_adaptive_contribution", 257 ]: 258 raise ValueError( 259 f"importance_type must be one of the predefined types, but got {self.importance_type}" 260 ) 261 262 # check importance summing strategy 263 if self.importance_summing_strategy not in [ 264 "add_latest", 265 "add_all", 266 "add_average", 267 "linear_decrease", 268 "quadratic_decrease", 269 "cubic_decrease", 270 "exponential_decrease", 271 "log_decrease", 272 "factorial_decrease", 273 ]: 274 raise ValueError( 275 f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}" 276 ) 277 278 # check importance scheduler type 279 if self.importance_scheduler_type not in [ 280 "linear_sparsity_reg", 281 "sparsity_reg", 282 "summative_mask_sparsity_reg", 283 ]: 284 raise ValueError( 285 f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}" 286 ) 287 288 # check neuron to weight importance aggregation mode 289 if self.neuron_to_weight_importance_aggregation_mode not in [ 290 "min", 291 "max", 292 "mean", 293 ]: 294 raise ValueError( 295 f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}" 296 ) 297 298 # check base values 299 if self.base_importance < 0: 300 raise ValueError( 301 f"base_importance must be >= 0, but got {self.base_importance}" 302 ) 303 if self.base_mask_sparsity_reg <= 0: 304 raise ValueError( 305 f"base_mask_sparsity_reg must be > 
0, but got {self.base_mask_sparsity_reg}" 306 ) 307 if self.base_linear <= 0: 308 raise ValueError(f"base_linear must be > 0, but got {self.base_linear}") 309 310 def on_train_start(self) -> None: 311 r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.""" 312 super().on_train_start() 313 314 self.importances[self.task_id] = ( 315 {} 316 ) # initialize the importance for the current task 317 318 # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time. 319 for layer_name in self.backbone.weighted_layer_names: 320 layer = self.backbone.get_layer_by_name( 321 layer_name 322 ) # get the layer by its name 323 num_units = layer.weight.shape[0] 324 325 # initialize the accumulated importance at the beginning of each task 326 self.importances[self.task_id][layer_name] = torch.zeros(num_units).to( 327 self.device 328 ) 329 330 # reset the number of steps counter for the current task 331 self.num_steps_t = 0 332 333 # initialize the summative neuron-wise importance at the beginning of the first task 334 if self.task_id == 1: 335 self.summative_importance_for_previous_tasks[layer_name] = torch.zeros( 336 num_units 337 ).to( 338 self.device 339 ) # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$ 340 341 def clip_grad_by_adjustment( 342 self, 343 network_sparsity: dict[str, Tensor], 344 ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]: 345 r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper. 346 347 Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code. 348 349 Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 350 351 **Args:** 352 - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler. 353 354 **Returns:** 355 - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 356 - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 357 - **capacity** (`Tensor`): the calculated network capacity. 358 """ 359 360 # initialize network capacity metric 361 capacity = HATNetworkCapacityMetric().to(self.device) 362 adjustment_rate_weight = {} 363 adjustment_rate_bias = {} 364 365 # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. 
(2) in the paper 366 for layer_name in self.backbone.weighted_layer_names: 367 368 layer = self.backbone.get_layer_by_name( 369 layer_name 370 ) # get the layer by its name 371 372 # placeholder for the adjustment rate to avoid the error of using it before assignment 373 adjustment_rate_weight_layer = 1 374 adjustment_rate_bias_layer = 1 375 376 # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, added the base value, and filtered by the mask 377 weight_importance, bias_importance = ( 378 self.backbone.get_layer_measure_parameter_wise( 379 neuron_wise_measure=self.summative_importance_for_previous_tasks, 380 layer_name=layer_name, 381 aggregation_mode=self.neuron_to_weight_importance_aggregation_mode, 382 ) 383 ) 384 385 weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise( 386 neuron_wise_measure=self.cumulative_mask_for_previous_tasks, 387 layer_name=layer_name, 388 aggregation_mode="min", 389 ) 390 391 # filter the weight importance by the cumulative mask 392 if self.filter_by_cumulative_mask: 393 weight_importance = weight_importance * weight_mask 394 bias_importance = bias_importance * bias_mask 395 396 network_sparsity_layer = network_sparsity[layer_name] 397 398 # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper 399 factor = network_sparsity_layer + self.base_mask_sparsity_reg 400 if self.importance_scheduler_type == "linear_sparsity_reg": 401 factor = factor * (self.task_id + self.base_linear) 402 elif self.importance_scheduler_type == "sparsity_reg": 403 pass 404 elif self.importance_scheduler_type == "summative_mask_sparsity_reg": 405 factor = factor * ( 406 self.summative_mask_for_previous_tasks + self.base_linear 407 ) 408 409 # calculate the adjustment rate 410 adjustment_rate_weight_layer = torch.div( 411 self.adjustment_intensity, 412 (factor * weight_importance + self.adjustment_intensity), 413 ) 414 415 adjustment_rate_bias_layer = torch.div( 416 self.adjustment_intensity, 417 (factor * bias_importance + self.adjustment_intensity), 418 ) 419 420 # apply the adjustment rate to the gradients 421 layer.weight.grad.data *= adjustment_rate_weight_layer 422 if layer.bias is not None: 423 layer.bias.grad.data *= adjustment_rate_bias_layer 424 425 # store the adjustment rate for logging 426 adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer 427 if layer.bias is not None: 428 adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer 429 430 # update network capacity metric 431 capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer) 432 433 return adjustment_rate_weight, adjustment_rate_bias, capacity.compute() 434 435 def on_train_batch_end( 436 self, outputs: dict[str, Any], batch: Any, batch_idx: int 437 ) -> None: 438 r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step. 439 440 **Args:** 441 - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`). 442 - **batch** (`Any`): training data batch. 443 - **batch_idx** (`int`): index of the current batch (for mask figure file name). 
444 """ 445 446 # get potential useful information from training batch 447 activations = outputs["activations"] 448 input = outputs["input"] 449 target = outputs["target"] 450 mask = outputs["mask"] 451 num_batches = self.trainer.num_training_batches 452 453 for layer_name in self.backbone.weighted_layer_names: 454 # layer-wise operation 455 456 activation = activations[layer_name] 457 458 # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. 459 if self.importance_type == "input_weight_abs_sum": 460 importance_step = self.get_importance_step_layer_weight_abs_sum( 461 layer_name=layer_name, 462 if_output_weight=False, 463 reciprocal=False, 464 ) 465 elif self.importance_type == "output_weight_abs_sum": 466 importance_step = self.get_importance_step_layer_weight_abs_sum( 467 layer_name=layer_name, 468 if_output_weight=True, 469 reciprocal=False, 470 ) 471 elif self.importance_type == "input_weight_gradient_abs_sum": 472 importance_step = ( 473 self.get_importance_step_layer_weight_gradient_abs_sum( 474 layer_name=layer_name, if_output_weight=False 475 ) 476 ) 477 elif self.importance_type == "output_weight_gradient_abs_sum": 478 importance_step = ( 479 self.get_importance_step_layer_weight_gradient_abs_sum( 480 layer_name=layer_name, if_output_weight=True 481 ) 482 ) 483 elif self.importance_type == "activation_abs": 484 importance_step = self.get_importance_step_layer_activation_abs( 485 activation=activation 486 ) 487 elif self.importance_type == "input_weight_abs_sum_x_activation_abs": 488 importance_step = ( 489 self.get_importance_step_layer_weight_abs_sum_x_activation_abs( 490 layer_name=layer_name, 491 activation=activation, 492 if_output_weight=False, 493 ) 494 ) 495 elif self.importance_type == "output_weight_abs_sum_x_activation_abs": 496 importance_step = ( 497 self.get_importance_step_layer_weight_abs_sum_x_activation_abs( 498 layer_name=layer_name, 499 activation=activation, 500 if_output_weight=True, 501 ) 502 ) 503 elif self.importance_type == "gradient_x_activation_abs": 504 importance_step = ( 505 self.get_importance_step_layer_gradient_x_activation_abs( 506 layer_name=layer_name, 507 input=input, 508 target=target, 509 batch_idx=batch_idx, 510 num_batches=num_batches, 511 ) 512 ) 513 elif self.importance_type == "input_weight_gradient_square_sum": 514 importance_step = ( 515 self.get_importance_step_layer_weight_gradient_square_sum( 516 layer_name=layer_name, 517 activation=activation, 518 if_output_weight=False, 519 ) 520 ) 521 elif self.importance_type == "output_weight_gradient_square_sum": 522 importance_step = ( 523 self.get_importance_step_layer_weight_gradient_square_sum( 524 layer_name=layer_name, 525 activation=activation, 526 if_output_weight=True, 527 ) 528 ) 529 elif ( 530 self.importance_type 531 == "input_weight_gradient_square_sum_x_activation_abs" 532 ): 533 importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 534 layer_name=layer_name, 535 activation=activation, 536 if_output_weight=False, 537 ) 538 elif ( 539 self.importance_type 540 == "output_weight_gradient_square_sum_x_activation_abs" 541 ): 542 importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 543 layer_name=layer_name, 544 activation=activation, 545 if_output_weight=True, 546 ) 547 elif self.importance_type == "conductance_abs": 548 importance_step = self.get_importance_step_layer_conductance_abs( 549 layer_name=layer_name, 550 input=input, 
551 baselines=None, 552 target=target, 553 batch_idx=batch_idx, 554 num_batches=num_batches, 555 ) 556 elif self.importance_type == "internal_influence_abs": 557 importance_step = self.get_importance_step_layer_internal_influence_abs( 558 layer_name=layer_name, 559 input=input, 560 baselines=None, 561 target=target, 562 batch_idx=batch_idx, 563 num_batches=num_batches, 564 ) 565 elif self.importance_type == "gradcam_abs": 566 importance_step = self.get_importance_step_layer_gradcam_abs( 567 layer_name=layer_name, 568 input=input, 569 target=target, 570 batch_idx=batch_idx, 571 num_batches=num_batches, 572 ) 573 elif self.importance_type == "deeplift_abs": 574 importance_step = self.get_importance_step_layer_deeplift_abs( 575 layer_name=layer_name, 576 input=input, 577 baselines=None, 578 target=target, 579 batch_idx=batch_idx, 580 num_batches=num_batches, 581 ) 582 elif self.importance_type == "deepliftshap_abs": 583 importance_step = self.get_importance_step_layer_deepliftshap_abs( 584 layer_name=layer_name, 585 input=input, 586 baselines=None, 587 target=target, 588 batch_idx=batch_idx, 589 num_batches=num_batches, 590 ) 591 elif self.importance_type == "gradientshap_abs": 592 importance_step = self.get_importance_step_layer_gradientshap_abs( 593 layer_name=layer_name, 594 input=input, 595 baselines=None, 596 target=target, 597 batch_idx=batch_idx, 598 num_batches=num_batches, 599 ) 600 elif self.importance_type == "integrated_gradients_abs": 601 importance_step = ( 602 self.get_importance_step_layer_integrated_gradients_abs( 603 layer_name=layer_name, 604 input=input, 605 baselines=None, 606 target=target, 607 batch_idx=batch_idx, 608 num_batches=num_batches, 609 ) 610 ) 611 elif self.importance_type == "feature_ablation_abs": 612 importance_step = self.get_importance_step_layer_feature_ablation_abs( 613 layer_name=layer_name, 614 input=input, 615 layer_baselines=None, 616 target=target, 617 batch_idx=batch_idx, 618 num_batches=num_batches, 619 ) 620 elif self.importance_type == "lrp_abs": 621 importance_step = self.get_importance_step_layer_lrp_abs( 622 layer_name=layer_name, 623 input=input, 624 target=target, 625 batch_idx=batch_idx, 626 num_batches=num_batches, 627 ) 628 elif self.importance_type == "cbp_adaptation": 629 importance_step = self.get_importance_step_layer_weight_abs_sum( 630 layer_name=layer_name, 631 if_output_weight=False, 632 reciprocal=True, 633 ) 634 elif self.importance_type == "cbp_adaptive_contribution": 635 importance_step = ( 636 self.get_importance_step_layer_cbp_adaptive_contribution( 637 layer_name=layer_name, 638 activation=activation, 639 ) 640 ) 641 642 importance_step = min_max_normalize( 643 importance_step 644 ) # min-max scaling the utility to $[0, 1]$. See Eq. (5) in the paper 645 646 # multiply the importance by the training mask. See Eq. (6) in the paper 647 if self.step_multiply_training_mask: 648 importance_step = importance_step * mask[layer_name] 649 650 # update accumulated importance 651 self.importances[self.task_id][layer_name] = ( 652 self.importances[self.task_id][layer_name] + importance_step 653 ) 654 655 # update number of steps counter 656 self.num_steps_t += 1 657 658 def on_train_end(self) -> None: 659 r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task.""" 660 super().on_train_end() # store the mask and update cumulative and summative masks 661 662 for layer_name in self.backbone.weighted_layer_names: 663 664 # average the neuron-wise step importance. See Eq. 
(4) in the paper
            self.importances[self.task_id][layer_name] = (
                self.importances[self.task_id][layer_name] / self.num_steps_t
            )

            # add the base importance. See Eq. (6) in the paper
            self.importances[self.task_id][layer_name] = (
                self.importances[self.task_id][layer_name] + self.base_importance
            )

            # filter unmasked importance
            if self.filter_unmasked_importance:
                self.importances[self.task_id][layer_name] = (
                    self.importances[self.task_id][layer_name]
                    * self.backbone.masks[f"{self.task_id}"][layer_name]
                )

            # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
            if self.importance_summing_strategy == "add_latest":
                self.summative_importance_for_previous_tasks[
                    layer_name
                ] += self.importances[self.task_id][layer_name]

            elif self.importance_summing_strategy == "add_all":
                for t in range(1, self.task_id + 1):
                    self.summative_importance_for_previous_tasks[
                        layer_name
                    ] += self.importances[t][layer_name]

            elif self.importance_summing_strategy == "add_average":
                for t in range(1, self.task_id + 1):
                    self.summative_importance_for_previous_tasks[layer_name] += (
                        self.importances[t][layer_name] / self.task_id
                    )
            else:
                # the decreasing strategies re-accumulate from zero with a task-dependent weight w_t
                self.summative_importance_for_previous_tasks[layer_name] = (
                    torch.zeros_like(
                        self.summative_importance_for_previous_tasks[layer_name]
                    ).to(self.device)
                )  # start adding from 0

                for t in range(1, self.task_id + 1):
                    if self.importance_summing_strategy == "linear_decrease":
                        s = self.importance_summing_strategy_linear_step
                        w_t = s * (self.task_id - t) + 1
                    elif self.importance_summing_strategy == "quadratic_decrease":
                        w_t = (self.task_id - t + 1) ** 2
                    elif self.importance_summing_strategy == "cubic_decrease":
                        w_t = (self.task_id - t + 1) ** 3
                    elif self.importance_summing_strategy == "exponential_decrease":
                        r = self.importance_summing_strategy_exponential_rate
                        w_t = r ** (self.task_id - t + 1)
                    elif self.importance_summing_strategy == "log_decrease":
                        a = self.importance_summing_strategy_log_base
                        w_t = math.log(self.task_id - t + 1, a) + 1  # +1 inside the log avoids log(0) for the latest task
                    elif self.importance_summing_strategy == "factorial_decrease":
                        w_t = math.factorial(self.task_id - t + 1)
                    else:
                        raise ValueError

                    # accumulate the weighted importance of task t inside the loop, so every previous task contributes
                    self.summative_importance_for_previous_tasks[layer_name] += (
                        self.importances[t][layer_name] * w_t
                    )

    def get_importance_step_layer_weight_abs_sum(
        self,
        layer_name: str,
        if_output_weight: bool,
        reciprocal: bool,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the layer's input or output weights.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
        - **reciprocal** (`bool`): whether to take the reciprocal.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
751 """ 752 layer = self.backbone.get_layer_by_name(layer_name) 753 754 if not if_output_weight: 755 weight_abs = torch.abs(layer.weight.data) 756 weight_abs_sum = torch.sum( 757 weight_abs, 758 dim=[ 759 i for i in range(weight_abs.dim()) if i != 0 760 ], # sum over the input dimension 761 ) 762 else: 763 weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 764 weight_abs_sum = torch.sum( 765 weight_abs, 766 dim=[ 767 i for i in range(weight_abs.dim()) if i != 1 768 ], # sum over the output dimension 769 ) 770 771 if reciprocal: 772 weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum) 773 importance_step_layer = weight_abs_sum_reciprocal 774 else: 775 importance_step_layer = weight_abs_sum 776 importance_step_layer = importance_step_layer.detach() 777 778 return importance_step_layer 779 780 def get_importance_step_layer_weight_gradient_abs_sum( 781 self: str, 782 layer_name: str, 783 if_output_weight: bool, 784 ) -> Tensor: 785 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights. 786 787 **Args:** 788 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 789 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 790 791 **Returns:** 792 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 793 """ 794 layer = self.backbone.get_layer_by_name(layer_name) 795 796 if not if_output_weight: 797 gradient_abs = torch.abs(layer.weight.grad.data) 798 gradient_abs_sum = torch.sum( 799 gradient_abs, 800 dim=[ 801 i for i in range(gradient_abs.dim()) if i != 0 802 ], # sum over the input dimension 803 ) 804 else: 805 gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data) 806 gradient_abs_sum = torch.sum( 807 gradient_abs, 808 dim=[ 809 i for i in range(gradient_abs.dim()) if i != 1 810 ], # sum over the output dimension 811 ) 812 813 importance_step_layer = gradient_abs_sum 814 importance_step_layer = importance_step_layer.detach() 815 816 return importance_step_layer 817 818 def get_importance_step_layer_activation_abs( 819 self: str, 820 activation: Tensor, 821 ) -> Tensor: 822 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum. 823 824 **Args:** 825 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 826 827 **Returns:** 828 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 829 """ 830 activation_abs_batch_mean = torch.mean( 831 torch.abs(activation), 832 dim=[ 833 i for i in range(activation.dim()) if i != 1 834 ], # average the features over batch samples 835 ) 836 importance_step_layer = activation_abs_batch_mean 837 importance_step_layer = importance_step_layer.detach() 838 839 return importance_step_layer 840 841 def get_importance_step_layer_weight_abs_sum_x_activation_abs( 842 self: str, 843 layer_name: str, 844 activation: Tensor, 845 if_output_weight: bool, 846 ) -> Tensor: 847 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. 
See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7). 848 849 **Args:** 850 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 851 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 852 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 853 854 **Returns:** 855 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 856 """ 857 layer = self.backbone.get_layer_by_name(layer_name) 858 859 if not if_output_weight: 860 weight_abs = torch.abs(layer.weight.data) 861 weight_abs_sum = torch.sum( 862 weight_abs, 863 dim=[ 864 i for i in range(weight_abs.dim()) if i != 0 865 ], # sum over the input dimension 866 ) 867 else: 868 weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 869 weight_abs_sum = torch.sum( 870 weight_abs, 871 dim=[ 872 i for i in range(weight_abs.dim()) if i != 1 873 ], # sum over the output dimension 874 ) 875 876 activation_abs_batch_mean = torch.mean( 877 torch.abs(activation), 878 dim=[ 879 i for i in range(activation.dim()) if i != 1 880 ], # average the features over batch samples 881 ) 882 883 importance_step_layer = weight_abs_sum * activation_abs_batch_mean 884 importance_step_layer = importance_step_layer.detach() 885 886 return importance_step_layer 887 888 def get_importance_step_layer_gradient_x_activation_abs( 889 self: str, 890 layer_name: str, 891 input: Tensor | tuple[Tensor, ...], 892 target: Tensor | None, 893 batch_idx: int, 894 num_batches: int, 895 ) -> Tensor: 896 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum. 897 898 **Args:** 899 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 900 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 901 - **target** (`Tensor` | `None`): the target batch of the training step. 902 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 903 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 904 905 **Returns:** 906 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
907 """ 908 layer = self.backbone.get_layer_by_name(layer_name) 909 910 input = input.requires_grad_() 911 912 # initialize the Layer Gradient X Activation object 913 layer_gradient_x_activation = LayerGradientXActivation( 914 forward_func=self.forward, layer=layer 915 ) 916 917 self.set_forward_func_return_logits_only(True) 918 # calculate layer attribution of the step 919 attribution = layer_gradient_x_activation.attribute( 920 inputs=input, 921 target=target, 922 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 923 ) 924 self.set_forward_func_return_logits_only(False) 925 926 attribution_abs_batch_mean = torch.mean( 927 torch.abs(attribution), 928 dim=[ 929 i for i in range(attribution.dim()) if i != 1 930 ], # average the features over batch samples 931 ) 932 933 importance_step_layer = attribution_abs_batch_mean 934 importance_step_layer = importance_step_layer.detach() 935 936 return importance_step_layer 937 938 def get_importance_step_layer_weight_gradient_square_sum( 939 self: str, 940 layer_name: str, 941 activation: Tensor, 942 if_output_weight: bool, 943 ) -> Tensor: 944 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114). 945 946 **Args:** 947 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 948 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 949 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 950 951 **Returns:** 952 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 953 """ 954 layer = self.backbone.get_layer_by_name(layer_name) 955 956 if not if_output_weight: 957 gradient_square = layer.weight.grad.data**2 958 gradient_square_sum = torch.sum( 959 gradient_square, 960 dim=[ 961 i for i in range(gradient_square.dim()) if i != 0 962 ], # sum over the input dimension 963 ) 964 else: 965 gradient_square = self.next_layer(layer_name).weight.grad.data**2 966 gradient_square_sum = torch.sum( 967 gradient_square, 968 dim=[ 969 i for i in range(gradient_square.dim()) if i != 1 970 ], # sum over the output dimension 971 ) 972 973 importance_step_layer = gradient_square_sum 974 importance_step_layer = importance_step_layer.detach() 975 976 return importance_step_layer 977 978 def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 979 self: str, 980 layer_name: str, 981 activation: Tensor, 982 if_output_weight: bool, 983 ) -> Tensor: 984 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114). 985 986 **Args:** 987 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 988 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 989 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 
        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_output_weight:
            gradient_square = layer.weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 0
                ],  # sum over the input dimension
            )
        else:
            gradient_square = self.next_layer(layer_name).weight.grad.data**2
            gradient_square_sum = torch.sum(
                gradient_square,
                dim=[
                    i for i in range(gradient_square.dim()) if i != 1
                ],  # sum over the output dimension
            )

        activation_abs_batch_mean = torch.mean(
            torch.abs(activation),
            dim=[
                i for i in range(activation.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = gradient_square_sum * activation_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer

    def get_importance_step_layer_conductance_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance for.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which the integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1046 """ 1047 layer = self.backbone.get_layer_by_name(layer_name) 1048 1049 # initialize the Layer Conductance object 1050 layer_conductance = LayerConductance(forward_func=self.forward, layer=layer) 1051 1052 self.set_forward_func_return_logits_only(True) 1053 # calculate layer attribution of the step 1054 attribution = layer_conductance.attribute( 1055 inputs=input, 1056 baselines=baselines, 1057 target=target, 1058 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1059 ) 1060 self.set_forward_func_return_logits_only(False) 1061 1062 attribution_abs_batch_mean = torch.mean( 1063 torch.abs(attribution), 1064 dim=[ 1065 i for i in range(attribution.dim()) if i != 1 1066 ], # average the features over batch samples 1067 ) 1068 1069 importance_step_layer = attribution_abs_batch_mean 1070 importance_step_layer = importance_step_layer.detach() 1071 1072 return importance_step_layer 1073 1074 def get_importance_step_layer_internal_influence_abs( 1075 self: str, 1076 layer_name: str, 1077 input: Tensor | tuple[Tensor, ...], 1078 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1079 target: Tensor | None, 1080 batch_idx: int, 1081 num_batches: int, 1082 ) -> Tensor: 1083 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum. 1084 1085 **Args:** 1086 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1087 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1088 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details. 1089 - **target** (`Tensor` | `None`): the target batch of the training step. 1090 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1091 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1092 1093 **Returns:** 1094 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1095 """ 1096 layer = self.backbone.get_layer_by_name(layer_name) 1097 1098 # initialize the Internal Influence object 1099 internal_influence = InternalInfluence(forward_func=self.forward, layer=layer) 1100 1101 # convert the target to long type to avoid error 1102 target = target.long() if target is not None else None 1103 1104 self.set_forward_func_return_logits_only(True) 1105 # calculate layer attribution of the step 1106 attribution = internal_influence.attribute( 1107 inputs=input, 1108 baselines=baselines, 1109 target=target, 1110 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1111 n_steps=5, # set 10 instead of default 50 to accelerate the computation 1112 ) 1113 self.set_forward_func_return_logits_only(False) 1114 1115 attribution_abs_batch_mean = torch.mean( 1116 torch.abs(attribution), 1117 dim=[ 1118 i for i in range(attribution.dim()) if i != 1 1119 ], # average the features over batch samples 1120 ) 1121 1122 importance_step_layer = attribution_abs_batch_mean 1123 importance_step_layer = importance_step_layer.detach() 1124 1125 return importance_step_layer 1126 1127 def get_importance_step_layer_gradcam_abs( 1128 self: str, 1129 layer_name: str, 1130 input: Tensor | tuple[Tensor, ...], 1131 target: Tensor | None, 1132 batch_idx: int, 1133 num_batches: int, 1134 ) -> Tensor: 1135 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum. 1136 1137 **Args:** 1138 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1139 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1140 - **target** (`Tensor` | `None`): the target batch of the training step. 1141 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1142 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1143 1144 **Returns:** 1145 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1146 """ 1147 layer = self.backbone.get_layer_by_name(layer_name) 1148 1149 # initialize the GradCAM object 1150 gradcam = LayerGradCam(forward_func=self.forward, layer=layer) 1151 1152 self.set_forward_func_return_logits_only(True) 1153 # calculate layer attribution of the step 1154 attribution = gradcam.attribute( 1155 inputs=input, 1156 target=target, 1157 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1158 ) 1159 self.set_forward_func_return_logits_only(False) 1160 1161 attribution_abs_batch_mean = torch.mean( 1162 torch.abs(attribution), 1163 dim=[ 1164 i for i in range(attribution.dim()) if i != 1 1165 ], # average the features over batch samples 1166 ) 1167 1168 importance_step_layer = attribution_abs_batch_mean 1169 importance_step_layer = importance_step_layer.detach() 1170 1171 return importance_step_layer 1172 1173 def get_importance_step_layer_deeplift_abs( 1174 self: str, 1175 layer_name: str, 1176 input: Tensor | tuple[Tensor, ...], 1177 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1178 target: Tensor | None, 1179 batch_idx: int, 1180 num_batches: int, 1181 ) -> Tensor: 1182 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum. 1183 1184 **Args:** 1185 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1186 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1187 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details. 1188 - **target** (`Tensor` | `None`): the target batch of the training step. 1189 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1190 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1191 1192 **Returns:** 1193 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1194 """ 1195 layer = self.backbone.get_layer_by_name(layer_name) 1196 1197 # initialize the Layer DeepLift object 1198 layer_deeplift = LayerDeepLift(model=self, layer=layer) 1199 1200 # convert the target to long type to avoid error 1201 target = target.long() if target is not None else None 1202 1203 self.set_forward_func_return_logits_only(True) 1204 # calculate layer attribution of the step 1205 attribution = layer_deeplift.attribute( 1206 inputs=input, 1207 baselines=baselines, 1208 target=target, 1209 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1210 ) 1211 self.set_forward_func_return_logits_only(False) 1212 1213 attribution_abs_batch_mean = torch.mean( 1214 torch.abs(attribution), 1215 dim=[ 1216 i for i in range(attribution.dim()) if i != 1 1217 ], # average the features over batch samples 1218 ) 1219 1220 importance_step_layer = attribution_abs_batch_mean 1221 importance_step_layer = importance_step_layer.detach() 1222 1223 return importance_step_layer 1224 1225 def get_importance_step_layer_deepliftshap_abs( 1226 self: str, 1227 layer_name: str, 1228 input: Tensor | tuple[Tensor, ...], 1229 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1230 target: Tensor | None, 1231 batch_idx: int, 1232 num_batches: int, 1233 ) -> Tensor: 1234 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum. 1235 1236 **Args:** 1237 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1238 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1239 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details. 1240 - **target** (`Tensor` | `None`): the target batch of the training step. 1241 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1242 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1243 1244 **Returns:** 1245 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1246 """ 1247 layer = self.backbone.get_layer_by_name(layer_name) 1248 1249 # initialize the Layer DeepLiftShap object 1250 layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer) 1251 1252 # convert the target to long type to avoid error 1253 target = target.long() if target is not None else None 1254 1255 self.set_forward_func_return_logits_only(True) 1256 # calculate layer attribution of the step 1257 attribution = layer_deepliftshap.attribute( 1258 inputs=input, 1259 baselines=baselines, 1260 target=target, 1261 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1262 ) 1263 self.set_forward_func_return_logits_only(False) 1264 1265 attribution_abs_batch_mean = torch.mean( 1266 torch.abs(attribution), 1267 dim=[ 1268 i for i in range(attribution.dim()) if i != 1 1269 ], # average the features over batch samples 1270 ) 1271 1272 importance_step_layer = attribution_abs_batch_mean 1273 importance_step_layer = importance_step_layer.detach() 1274 1275 return importance_step_layer 1276 1277 def get_importance_step_layer_gradientshap_abs( 1278 self: str, 1279 layer_name: str, 1280 input: Tensor | tuple[Tensor, ...], 1281 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1282 target: Tensor | None, 1283 batch_idx: int, 1284 num_batches: int, 1285 ) -> Tensor: 1286 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum. 1287 1288 **Args:** 1289 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1290 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1291 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero. 1292 - **target** (`Tensor` | `None`): the target batch of the training step. 1293 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1294 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1295 1296 **Returns:** 1297 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1298 """ 1299 layer = self.backbone.get_layer_by_name(layer_name) 1300 1301 if baselines is None: 1302 baselines = torch.zeros_like( 1303 input 1304 ) # baselines are mandatory for GradientShap API. 
We explicitly set them to zero 1305 1306 # initialize the Layer GradientShap object 1307 layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer) 1308 1309 # convert the target to long type to avoid error 1310 target = target.long() if target is not None else None 1311 1312 self.set_forward_func_return_logits_only(True) 1313 # calculate layer attribution of the step 1314 attribution = layer_gradientshap.attribute( 1315 inputs=input, 1316 baselines=baselines, 1317 target=target, 1318 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1319 ) 1320 self.set_forward_func_return_logits_only(False) 1321 1322 attribution_abs_batch_mean = torch.mean( 1323 torch.abs(attribution), 1324 dim=[ 1325 i for i in range(attribution.dim()) if i != 1 1326 ], # average the features over batch samples 1327 ) 1328 1329 importance_step_layer = attribution_abs_batch_mean 1330 importance_step_layer = importance_step_layer.detach() 1331 1332 return importance_step_layer 1333 1334 def get_importance_step_layer_integrated_gradients_abs( 1335 self: str, 1336 layer_name: str, 1337 input: Tensor | tuple[Tensor, ...], 1338 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1339 target: Tensor | None, 1340 batch_idx: int, 1341 num_batches: int, 1342 ) -> Tensor: 1343 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum. 1344 1345 **Args:** 1346 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1347 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1348 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details. 1349 - **target** (`Tensor` | `None`): the target batch of the training step. 1350 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1351 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1352 1353 **Returns:** 1354 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1355 """ 1356 layer = self.backbone.get_layer_by_name(layer_name) 1357 1358 # initialize the Layer Integrated Gradients object 1359 layer_integrated_gradients = LayerIntegratedGradients( 1360 forward_func=self.forward, layer=layer 1361 ) 1362 1363 self.set_forward_func_return_logits_only(True) 1364 # calculate layer attribution of the step 1365 attribution = layer_integrated_gradients.attribute( 1366 inputs=input, 1367 baselines=baselines, 1368 target=target, 1369 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1370 ) 1371 self.set_forward_func_return_logits_only(False) 1372 1373 attribution_abs_batch_mean = torch.mean( 1374 torch.abs(attribution), 1375 dim=[ 1376 i for i in range(attribution.dim()) if i != 1 1377 ], # average the features over batch samples 1378 ) 1379 1380 importance_step_layer = attribution_abs_batch_mean 1381 importance_step_layer = importance_step_layer.detach() 1382 1383 return importance_step_layer 1384 1385 def get_importance_step_layer_feature_ablation_abs( 1386 self: str, 1387 layer_name: str, 1388 input: Tensor | tuple[Tensor, ...], 1389 layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1390 target: Tensor | None, 1391 batch_idx: int, 1392 num_batches: int, 1393 if_captum: bool = False, 1394 ) -> Tensor: 1395 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum. 1396 1397 **Args:** 1398 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1399 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1400 - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details. 1401 - **target** (`Tensor` | `None`): the target batch of the training step. 1402 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1403 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1404 - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster. 1405 1406 **Returns:** 1407 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1408 """ 1409 layer = self.backbone.get_layer_by_name(layer_name) 1410 1411 if not if_captum: 1412 # 1. Baseline logits (take first element of forward output) 1413 baseline_out, _, _ = self.forward( 1414 input, "train", batch_idx, num_batches, self.task_id 1415 ) 1416 if target is not None: 1417 baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1) 1418 else: 1419 baseline_scores = baseline_out.sum(dim=1) 1420 1421 # 2. 
Capture layer’s output shape 1422 activs = {} 1423 handle = layer.register_forward_hook( 1424 lambda module, inp, out: activs.setdefault("output", out.detach()) 1425 ) 1426 _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id) 1427 handle.remove() 1428 layer_output = activs["output"] # shape (B, F, ...) 1429 1430 # 3. Build baseline tensor matching that shape 1431 if layer_baselines is None: 1432 baseline_tensor = torch.zeros_like(layer_output) 1433 elif isinstance(layer_baselines, (int, float)): 1434 baseline_tensor = torch.full_like(layer_output, layer_baselines) 1435 elif isinstance(layer_baselines, Tensor): 1436 if layer_baselines.shape == layer_output.shape: 1437 baseline_tensor = layer_baselines 1438 elif layer_baselines.shape == layer_output.shape[1:]: 1439 baseline_tensor = layer_baselines.unsqueeze(0).repeat( 1440 layer_output.size(0), *([1] * layer_baselines.ndim) 1441 ) 1442 else: 1443 raise ValueError(...) 1444 else: 1445 raise ValueError(...) 1446 1447 B, F = layer_output.size(0), layer_output.size(1) 1448 1449 # 4. Create a “mega-batch” replicating the input F times 1450 if isinstance(input, tuple): 1451 mega_inputs = tuple( 1452 t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:]) 1453 for t in input 1454 ) 1455 else: 1456 mega_inputs = ( 1457 input.unsqueeze(0) 1458 .repeat(F, *([1] * input.ndim)) 1459 .view(-1, *input.shape[1:]) 1460 ) 1461 1462 # 5. Equally replicate the baseline tensor 1463 mega_baseline = ( 1464 baseline_tensor.unsqueeze(0) 1465 .repeat(F, *([1] * baseline_tensor.ndim)) 1466 .view(-1, *baseline_tensor.shape[1:]) 1467 ) 1468 1469 # 6. Precompute vectorized indices 1470 device = layer_output.device 1471 positions = torch.arange(F * B, device=device) # [0,1,...,F*B-1] 1472 feat_idx = torch.arange(F, device=device).repeat_interleave( 1473 B 1474 ) # [0,0,...,1,1,...,F-1] 1475 1476 # 7. One hook to zero out each channel slice across the mega-batch 1477 def mega_ablate_hook(module, inp, out): 1478 out_mod = out.clone() 1479 # for each sample in mega-batch, zero its corresponding channel 1480 out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx] 1481 return out_mod 1482 1483 h = layer.register_forward_hook(mega_ablate_hook) 1484 out_all, _, _ = self.forward( 1485 mega_inputs, "train", batch_idx, num_batches, self.task_id 1486 ) 1487 h.remove() 1488 1489 # 8. 
Recover scores, reshape [F*B] → [F, B], diff & mean 1490 if target is not None: 1491 tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1) 1492 scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1) 1493 else: 1494 scores_all = out_all.sum(dim=1) 1495 1496 scores_all = scores_all.view(F, B) 1497 diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all) 1498 importance_step_layer = diffs.mean(dim=1).detach() # [F] 1499 1500 return importance_step_layer 1501 1502 else: 1503 # initialize the Layer Feature Ablation object 1504 layer_feature_ablation = LayerFeatureAblation( 1505 forward_func=self.forward, layer=layer 1506 ) 1507 1508 # calculate layer attribution of the step 1509 self.set_forward_func_return_logits_only(True) 1510 attribution = layer_feature_ablation.attribute( 1511 inputs=input, 1512 layer_baselines=layer_baselines, 1513 # target=target, # disable target to enable perturbations_per_eval 1514 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1515 perturbations_per_eval=128, # to accelerate the computation 1516 ) 1517 self.set_forward_func_return_logits_only(False) 1518 1519 attribution_abs_batch_mean = torch.mean( 1520 torch.abs(attribution), 1521 dim=[ 1522 i for i in range(attribution.dim()) if i != 1 1523 ], # average the features over batch samples 1524 ) 1525 1526 importance_step_layer = attribution_abs_batch_mean 1527 importance_step_layer = importance_step_layer.detach() 1528 1529 return importance_step_layer 1530 1531 def get_importance_step_layer_lrp_abs( 1532 self: str, 1533 layer_name: str, 1534 input: Tensor | tuple[Tensor, ...], 1535 target: Tensor | None, 1536 batch_idx: int, 1537 num_batches: int, 1538 ) -> Tensor: 1539 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum. 1540 1541 **Args:** 1542 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1543 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1544 - **target** (`Tensor` | `None`): the target batch of the training step. 1545 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1546 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1547 1548 **Returns:** 1549 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1550 """ 1551 layer = self.backbone.get_layer_by_name(layer_name) 1552 1553 # initialize the Layer LRP object 1554 layer_lrp = LayerLRP(model=self, layer=layer) 1555 1556 # set model to evaluation mode to prevent updating the model parameters 1557 self.eval() 1558 1559 self.set_forward_func_return_logits_only(True) 1560 # calculate layer attribution of the step 1561 attribution = layer_lrp.attribute( 1562 inputs=input, 1563 target=target, 1564 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1565 ) 1566 self.set_forward_func_return_logits_only(False) 1567 1568 attribution_abs_batch_mean = torch.mean( 1569 torch.abs(attribution), 1570 dim=[ 1571 i for i in range(attribution.dim()) if i != 1 1572 ], # average the features over batch samples 1573 ) 1574 1575 importance_step_layer = attribution_abs_batch_mean 1576 importance_step_layer = importance_step_layer.detach() 1577 1578 return importance_step_layer 1579 1580 def get_importance_step_layer_cbp_adaptive_contribution( 1581 self: str, 1582 layer_name: str, 1583 activation: Tensor, 1584 ) -> Tensor: 1585 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the reciprocal of sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7). 1586 1587 **Args:** 1588 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1589 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 1590 1591 **Returns:** 1592 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1593 """ 1594 layer = self.backbone.get_layer_by_name(layer_name) 1595 1596 input_weight_abs = torch.abs(layer.weight.data) 1597 input_weight_abs_sum = torch.sum( 1598 input_weight_abs, 1599 dim=[ 1600 i for i in range(input_weight_abs.dim()) if i != 0 1601 ], # sum over the input dimension 1602 ) 1603 input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum) 1604 1605 output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 1606 output_weight_abs_sum = torch.sum( 1607 output_weight_abs, 1608 dim=[ 1609 i for i in range(output_weight_abs.dim()) if i != 1 1610 ], # sum over the output dimension 1611 ) 1612 1613 activation_abs_batch_mean = torch.mean( 1614 torch.abs(activation), 1615 dim=[ 1616 i for i in range(activation.dim()) if i != 1 1617 ], # average the features over batch samples 1618 ) 1619 1620 importance_step_layer = ( 1621 output_weight_abs_sum 1622 * activation_abs_batch_mean 1623 * input_weight_abs_sum_reciprocal 1624 ) 1625 importance_step_layer = importance_step_layer.detach() 1626 1627 return importance_step_layer
37class FGAdaHAT(AdaHAT): 38 r"""FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm. 39 40 An architecture-based continual learning approach that improves [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT. 41 42 We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components. 43 """ 44 45 def __init__( 46 self, 47 backbone: HATMaskBackbone, 48 heads: HeadsTIL, 49 adjustment_intensity: float, 50 importance_type: str, 51 importance_summing_strategy: str, 52 importance_scheduler_type: str, 53 neuron_to_weight_importance_aggregation_mode: str, 54 s_max: float, 55 clamp_threshold: float, 56 mask_sparsity_reg_factor: float, 57 mask_sparsity_reg_mode: str = "original", 58 base_importance: float = 0.01, 59 base_mask_sparsity_reg: float = 0.1, 60 base_linear: float = 10, 61 filter_by_cumulative_mask: bool = False, 62 filter_unmasked_importance: bool = True, 63 step_multiply_training_mask: bool = True, 64 task_embedding_init_mode: str = "N01", 65 importance_summing_strategy_linear_step: float | None = None, 66 importance_summing_strategy_exponential_rate: float | None = None, 67 importance_summing_strategy_log_base: float | None = None, 68 non_algorithmic_hparams: dict[str, Any] = {}, 69 ) -> None: 70 r"""Initialize the FG-AdaHAT algorithm with the network. 71 72 **Args:** 73 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 74 - **heads** (`HeadsTIL`): output heads. FG-AdaHAT supports only TIL (Task-Incremental Learning). 75 - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper). 76 - **importance_type** (`str`): the type of neuron-wise importance, must be one of: 77 1. 'input_weight_abs_sum': sum of absolute input weights; 78 2. 'output_weight_abs_sum': sum of absolute output weights; 79 3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper); 80 4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper); 81 5. 'activation_abs': absolute activation; 82 6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper); 83 7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper); 84 8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation; 85 9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights; 86 10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights; 87 11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper); 88 12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation; 89 13. 'conductance_abs': absolute layer conductance; 90 14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper); 91 15. 'gradcam_abs': absolute Grad-CAM; 92 16. 
'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper); 93 17. 'deepliftshap_abs': absolute DeepLIFT-SHAP; 94 18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper); 95 19. 'integrated_gradients_abs': absolute Integrated Gradients; 96 20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper); 97 21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP); 98 22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 99 23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 100 - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of: 101 1. 'add_latest': add the latest neuron-wise importance to the summative importance; 102 2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance; 103 3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance; 104 4. 'linear_decrease': weigh the previous neuron-wise importance by a factor that decreases linearly with the task ID; 105 5. 'quadratic_decrease': weigh the previous neuron-wise importance by a factor that decreases quadratically with the task ID; 106 6. 'cubic_decrease': weigh the previous neuron-wise importance by a factor that decreases cubically with the task ID; 107 7. 'exponential_decrease': weigh the previous neuron-wise importance by a factor that decreases exponentially with the task ID; 108 8. 'log_decrease': weigh the previous neuron-wise importance by a factor that decreases logarithmically with the task ID; 109 9. 'factorial_decrease': weigh the previous neuron-wise importance by a factor that decreases factorially with the task ID; 110 - **importance_scheduler_type** (`str`): the importance scheduler, i.e., the factor $c^t$ applied to the parameter importance. Must be one of: 111 1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`); 112 2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$; 113 3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$. 114 - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of: 115 1. 'min': take the minimum of neuron-wise importance for each weight; 116 2. 'max': take the maximum of neuron-wise importance for each weight; 117 3. 'mean': take the mean of neuron-wise importance for each weight. 118 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 119 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 120 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
121 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 122 1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 123 2. 'cross': the cross version of mask sparsity regularization. 124 - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01. 125 - **base_mask_sparsity_reg** (`float`): base value added to the mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1. 126 - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10. 127 - **filter_by_cumulative_mask** (`bool`): whether to multiply the importance by the cumulative mask when calculating the adjustment rate. Default: False. 128 - **filter_unmasked_importance** (`bool`): whether to filter out unmasked importance values (set them to 0) at the end of task training. Default: True. 129 - **step_multiply_training_mask** (`bool`): whether to multiply the importance by the training mask at each training step. Default: True. 130 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 131 1. 'N01' (default): standard normal distribution $N(0, 1)$. 132 2. 'U-11': uniform distribution $U(-1, 1)$. 133 3. 'U01': uniform distribution $U(0, 1)$. 134 4. 'U-10': uniform distribution $U(-1, 0)$. 135 5. 'last': inherit the task embedding from the last task. 136 - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0. 137 - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1. 138 - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1. 139 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself but are passed to this `LightningModule` from the config, such as optimizer and learning rate scheduler settings. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
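For orientation, a self-contained numeric sketch of how `adjustment_intensity`, `base_linear`, and `base_mask_sparsity_reg` interact through the 'linear_sparsity_reg' scheduler and the gradient adjustment rate used in `clip_grad_by_adjustment`; the numbers are arbitrary examples, not recommended settings:

```python
# Arbitrary example numbers, not recommendations.
import torch

task_id = 3                    # t
adjustment_intensity = 0.1     # alpha
base_linear = 10.0             # b_L
base_mask_sparsity_reg = 0.1   # b_R
mask_sparsity_reg = 0.42       # R(M^t, M^{<t}) measured for one layer
importance = torch.tensor([0.0, 0.5, 1.0])  # scaled weight-wise importance

# 'linear_sparsity_reg' scheduler: c^t = (t + b_L) * (R + b_R)
factor = (task_id + base_linear) * (mask_sparsity_reg + base_mask_sparsity_reg)

# gradient adjustment rate: alpha / (c^t * I + alpha); large importance -> tiny updates
adjustment_rate = adjustment_intensity / (factor * importance + adjustment_intensity)
print(adjustment_rate)  # tensor([1.0000, 0.0287, 0.0146])
```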
140 141 """ 142 super().__init__( 143 backbone=backbone, 144 heads=heads, 145 adjustment_mode=None, # use the own adjustment mechanism of FG-AdaHAT 146 adjustment_intensity=adjustment_intensity, 147 s_max=s_max, 148 clamp_threshold=clamp_threshold, 149 mask_sparsity_reg_factor=mask_sparsity_reg_factor, 150 mask_sparsity_reg_mode=mask_sparsity_reg_mode, 151 task_embedding_init_mode=task_embedding_init_mode, 152 epsilon=base_mask_sparsity_reg, # the epsilon is now the base mask sparsity regularization factor 153 non_algorithmic_hparams=non_algorithmic_hparams, 154 ) 155 156 # save additional algorithmic hyperparameters 157 self.save_hyperparameters( 158 "adjustment_intensity", 159 "importance_type", 160 "importance_summing_strategy", 161 "importance_scheduler_type", 162 "neuron_to_weight_importance_aggregation_mode", 163 "s_max", 164 "clamp_threshold", 165 "mask_sparsity_reg_factor", 166 "mask_sparsity_reg_mode", 167 "base_importance", 168 "base_mask_sparsity_reg", 169 "base_linear", 170 "filter_by_cumulative_mask", 171 "filter_unmasked_importance", 172 "step_multiply_training_mask", 173 ) 174 175 self.importance_type: str | None = importance_type 176 r"""The type of the neuron-wise importance added to AdaHAT importance.""" 177 178 self.importance_scheduler_type: str = importance_scheduler_type 179 r"""The type of the importance scheduler.""" 180 self.neuron_to_weight_importance_aggregation_mode: str = ( 181 neuron_to_weight_importance_aggregation_mode 182 ) 183 r"""The mode of aggregation from neuron-wise to weight-wise importance. """ 184 self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask 185 r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate.""" 186 self.filter_unmasked_importance: bool = filter_unmasked_importance 187 r"""The flag to filter unmasked importance values (set them to 0) at the end of task training.""" 188 self.step_multiply_training_mask: bool = step_multiply_training_mask 189 r"""The flag to multiply the training mask to the importance at each training step.""" 190 191 # importance summing strategy 192 self.importance_summing_strategy: str = importance_summing_strategy 193 r"""The strategy to sum the neuron-wise importance for previous tasks.""" 194 if importance_summing_strategy_linear_step is not None: 195 self.importance_summing_strategy_linear_step: float = ( 196 importance_summing_strategy_linear_step 197 ) 198 r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease').""" 199 if importance_summing_strategy_exponential_rate is not None: 200 self.importance_summing_strategy_exponential_rate: float = ( 201 importance_summing_strategy_exponential_rate 202 ) 203 r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """ 204 if importance_summing_strategy_log_base is not None: 205 self.importance_summing_strategy_log_base: float = ( 206 importance_summing_strategy_log_base 207 ) 208 r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """ 209 210 # base values 211 self.base_importance: float = base_importance 212 r"""The base value added to the importance to avoid zero. """ 213 self.base_mask_sparsity_reg: float = base_mask_sparsity_reg 214 r"""The base value added to the mask sparsity regularization to avoid zero. 
""" 215 self.base_linear: float = base_linear 216 r"""The base value added to the linear layer to avoid zero. """ 217 218 self.importances: dict[int, dict[str, Tensor]] = {} 219 r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """ 220 self.summative_importance_for_previous_tasks: dict[str, Tensor] = {} 221 r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """ 222 223 self.num_steps_t: int 224 r"""The number of training steps for the current task `self.task_id`.""" 225 # set manual optimization 226 self.automatic_optimization = False 227 228 FGAdaHAT.sanity_check(self) 229 230 def sanity_check(self) -> None: 231 r"""Sanity check.""" 232 233 # check importance type 234 if self.importance_type not in [ 235 "input_weight_abs_sum", 236 "output_weight_abs_sum", 237 "input_weight_gradient_abs_sum", 238 "output_weight_gradient_abs_sum", 239 "activation_abs", 240 "input_weight_abs_sum_x_activation_abs", 241 "output_weight_abs_sum_x_activation_abs", 242 "gradient_x_activation_abs", 243 "input_weight_gradient_square_sum", 244 "output_weight_gradient_square_sum", 245 "input_weight_gradient_square_sum_x_activation_abs", 246 "output_weight_gradient_square_sum_x_activation_abs", 247 "conductance_abs", 248 "internal_influence_abs", 249 "gradcam_abs", 250 "deeplift_abs", 251 "deepliftshap_abs", 252 "gradientshap_abs", 253 "integrated_gradients_abs", 254 "feature_ablation_abs", 255 "lrp_abs", 256 "cbp_adaptation", 257 "cbp_adaptive_contribution", 258 ]: 259 raise ValueError( 260 f"importance_type must be one of the predefined types, but got {self.importance_type}" 261 ) 262 263 # check importance summing strategy 264 if self.importance_summing_strategy not in [ 265 "add_latest", 266 "add_all", 267 "add_average", 268 "linear_decrease", 269 "quadratic_decrease", 270 "cubic_decrease", 271 "exponential_decrease", 272 "log_decrease", 273 "factorial_decrease", 274 ]: 275 raise ValueError( 276 f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}" 277 ) 278 279 # check importance scheduler type 280 if self.importance_scheduler_type not in [ 281 "linear_sparsity_reg", 282 "sparsity_reg", 283 "summative_mask_sparsity_reg", 284 ]: 285 raise ValueError( 286 f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}" 287 ) 288 289 # check neuron to weight importance aggregation mode 290 if self.neuron_to_weight_importance_aggregation_mode not in [ 291 "min", 292 "max", 293 "mean", 294 ]: 295 raise ValueError( 296 f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}" 297 ) 298 299 # check base values 300 if self.base_importance < 0: 301 raise ValueError( 302 f"base_importance must be >= 0, but got {self.base_importance}" 303 ) 304 if self.base_mask_sparsity_reg <= 0: 305 raise ValueError( 306 f"base_mask_sparsity_reg must be > 
0, but got {self.base_mask_sparsity_reg}" 307 ) 308 if self.base_linear <= 0: 309 raise ValueError(f"base_linear must be > 0, but got {self.base_linear}") 310 311 def on_train_start(self) -> None: 312 r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.""" 313 super().on_train_start() 314 315 self.importances[self.task_id] = ( 316 {} 317 ) # initialize the importance for the current task 318 319 # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time. 320 for layer_name in self.backbone.weighted_layer_names: 321 layer = self.backbone.get_layer_by_name( 322 layer_name 323 ) # get the layer by its name 324 num_units = layer.weight.shape[0] 325 326 # initialize the accumulated importance at the beginning of each task 327 self.importances[self.task_id][layer_name] = torch.zeros(num_units).to( 328 self.device 329 ) 330 331 # reset the number of steps counter for the current task 332 self.num_steps_t = 0 333 334 # initialize the summative neuron-wise importance at the beginning of the first task 335 if self.task_id == 1: 336 self.summative_importance_for_previous_tasks[layer_name] = torch.zeros( 337 num_units 338 ).to( 339 self.device 340 ) # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$ 341 342 def clip_grad_by_adjustment( 343 self, 344 network_sparsity: dict[str, Tensor], 345 ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]: 346 r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper. 347 348 Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code. 349 350 Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 351 352 **Args:** 353 - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler. 354 355 **Returns:** 356 - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 357 - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 358 - **capacity** (`Tensor`): the calculated network capacity. 359 """ 360 361 # initialize network capacity metric 362 capacity = HATNetworkCapacityMetric().to(self.device) 363 adjustment_rate_weight = {} 364 adjustment_rate_bias = {} 365 366 # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. 
(2) in the paper 367 for layer_name in self.backbone.weighted_layer_names: 368 369 layer = self.backbone.get_layer_by_name( 370 layer_name 371 ) # get the layer by its name 372 373 # placeholder for the adjustment rate to avoid the error of using it before assignment 374 adjustment_rate_weight_layer = 1 375 adjustment_rate_bias_layer = 1 376 377 # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, added the base value, and filtered by the mask 378 weight_importance, bias_importance = ( 379 self.backbone.get_layer_measure_parameter_wise( 380 neuron_wise_measure=self.summative_importance_for_previous_tasks, 381 layer_name=layer_name, 382 aggregation_mode=self.neuron_to_weight_importance_aggregation_mode, 383 ) 384 ) 385 386 weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise( 387 neuron_wise_measure=self.cumulative_mask_for_previous_tasks, 388 layer_name=layer_name, 389 aggregation_mode="min", 390 ) 391 392 # filter the weight importance by the cumulative mask 393 if self.filter_by_cumulative_mask: 394 weight_importance = weight_importance * weight_mask 395 bias_importance = bias_importance * bias_mask 396 397 network_sparsity_layer = network_sparsity[layer_name] 398 399 # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper 400 factor = network_sparsity_layer + self.base_mask_sparsity_reg 401 if self.importance_scheduler_type == "linear_sparsity_reg": 402 factor = factor * (self.task_id + self.base_linear) 403 elif self.importance_scheduler_type == "sparsity_reg": 404 pass 405 elif self.importance_scheduler_type == "summative_mask_sparsity_reg": 406 factor = factor * ( 407 self.summative_mask_for_previous_tasks + self.base_linear 408 ) 409 410 # calculate the adjustment rate 411 adjustment_rate_weight_layer = torch.div( 412 self.adjustment_intensity, 413 (factor * weight_importance + self.adjustment_intensity), 414 ) 415 416 adjustment_rate_bias_layer = torch.div( 417 self.adjustment_intensity, 418 (factor * bias_importance + self.adjustment_intensity), 419 ) 420 421 # apply the adjustment rate to the gradients 422 layer.weight.grad.data *= adjustment_rate_weight_layer 423 if layer.bias is not None: 424 layer.bias.grad.data *= adjustment_rate_bias_layer 425 426 # store the adjustment rate for logging 427 adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer 428 if layer.bias is not None: 429 adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer 430 431 # update network capacity metric 432 capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer) 433 434 return adjustment_rate_weight, adjustment_rate_bias, capacity.compute() 435 436 def on_train_batch_end( 437 self, outputs: dict[str, Any], batch: Any, batch_idx: int 438 ) -> None: 439 r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step. 440 441 **Args:** 442 - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`). 443 - **batch** (`Any`): training data batch. 444 - **batch_idx** (`int`): index of the current batch (for mask figure file name). 
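For illustration, a standalone sketch of the per-step bookkeeping this method performs (min-max scaling, gating by the training mask, and running accumulation); the toy numbers and the local `min_max_scale` helper are stand-ins, not the `clarena` utilities:

```python
# Toy sketch of the per-step importance accumulation (arbitrary numbers).
import torch

def min_max_scale(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # local stand-in for clarena.utils.transforms.min_max_normalize
    return (x - x.min()) / (x.max() - x.min() + eps)

raw_step_importance = torch.tensor([0.2, 3.0, 0.9, 1.5])  # one value per unit
training_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])        # this step's HAT mask

step_importance = min_max_scale(raw_step_importance) * training_mask

accumulated = torch.zeros(4)
accumulated += step_importance  # repeated at every training step
num_steps = 1                   # incremented alongside

# at the end of the task: average over steps, then add the base importance
task_importance = accumulated / num_steps + 0.01
```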
445 """ 446 447 # get potential useful information from training batch 448 activations = outputs["activations"] 449 input = outputs["input"] 450 target = outputs["target"] 451 mask = outputs["mask"] 452 num_batches = self.trainer.num_training_batches 453 454 for layer_name in self.backbone.weighted_layer_names: 455 # layer-wise operation 456 457 activation = activations[layer_name] 458 459 # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. 460 if self.importance_type == "input_weight_abs_sum": 461 importance_step = self.get_importance_step_layer_weight_abs_sum( 462 layer_name=layer_name, 463 if_output_weight=False, 464 reciprocal=False, 465 ) 466 elif self.importance_type == "output_weight_abs_sum": 467 importance_step = self.get_importance_step_layer_weight_abs_sum( 468 layer_name=layer_name, 469 if_output_weight=True, 470 reciprocal=False, 471 ) 472 elif self.importance_type == "input_weight_gradient_abs_sum": 473 importance_step = ( 474 self.get_importance_step_layer_weight_gradient_abs_sum( 475 layer_name=layer_name, if_output_weight=False 476 ) 477 ) 478 elif self.importance_type == "output_weight_gradient_abs_sum": 479 importance_step = ( 480 self.get_importance_step_layer_weight_gradient_abs_sum( 481 layer_name=layer_name, if_output_weight=True 482 ) 483 ) 484 elif self.importance_type == "activation_abs": 485 importance_step = self.get_importance_step_layer_activation_abs( 486 activation=activation 487 ) 488 elif self.importance_type == "input_weight_abs_sum_x_activation_abs": 489 importance_step = ( 490 self.get_importance_step_layer_weight_abs_sum_x_activation_abs( 491 layer_name=layer_name, 492 activation=activation, 493 if_output_weight=False, 494 ) 495 ) 496 elif self.importance_type == "output_weight_abs_sum_x_activation_abs": 497 importance_step = ( 498 self.get_importance_step_layer_weight_abs_sum_x_activation_abs( 499 layer_name=layer_name, 500 activation=activation, 501 if_output_weight=True, 502 ) 503 ) 504 elif self.importance_type == "gradient_x_activation_abs": 505 importance_step = ( 506 self.get_importance_step_layer_gradient_x_activation_abs( 507 layer_name=layer_name, 508 input=input, 509 target=target, 510 batch_idx=batch_idx, 511 num_batches=num_batches, 512 ) 513 ) 514 elif self.importance_type == "input_weight_gradient_square_sum": 515 importance_step = ( 516 self.get_importance_step_layer_weight_gradient_square_sum( 517 layer_name=layer_name, 518 activation=activation, 519 if_output_weight=False, 520 ) 521 ) 522 elif self.importance_type == "output_weight_gradient_square_sum": 523 importance_step = ( 524 self.get_importance_step_layer_weight_gradient_square_sum( 525 layer_name=layer_name, 526 activation=activation, 527 if_output_weight=True, 528 ) 529 ) 530 elif ( 531 self.importance_type 532 == "input_weight_gradient_square_sum_x_activation_abs" 533 ): 534 importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 535 layer_name=layer_name, 536 activation=activation, 537 if_output_weight=False, 538 ) 539 elif ( 540 self.importance_type 541 == "output_weight_gradient_square_sum_x_activation_abs" 542 ): 543 importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 544 layer_name=layer_name, 545 activation=activation, 546 if_output_weight=True, 547 ) 548 elif self.importance_type == "conductance_abs": 549 importance_step = self.get_importance_step_layer_conductance_abs( 550 layer_name=layer_name, 551 input=input, 
552 baselines=None, 553 target=target, 554 batch_idx=batch_idx, 555 num_batches=num_batches, 556 ) 557 elif self.importance_type == "internal_influence_abs": 558 importance_step = self.get_importance_step_layer_internal_influence_abs( 559 layer_name=layer_name, 560 input=input, 561 baselines=None, 562 target=target, 563 batch_idx=batch_idx, 564 num_batches=num_batches, 565 ) 566 elif self.importance_type == "gradcam_abs": 567 importance_step = self.get_importance_step_layer_gradcam_abs( 568 layer_name=layer_name, 569 input=input, 570 target=target, 571 batch_idx=batch_idx, 572 num_batches=num_batches, 573 ) 574 elif self.importance_type == "deeplift_abs": 575 importance_step = self.get_importance_step_layer_deeplift_abs( 576 layer_name=layer_name, 577 input=input, 578 baselines=None, 579 target=target, 580 batch_idx=batch_idx, 581 num_batches=num_batches, 582 ) 583 elif self.importance_type == "deepliftshap_abs": 584 importance_step = self.get_importance_step_layer_deepliftshap_abs( 585 layer_name=layer_name, 586 input=input, 587 baselines=None, 588 target=target, 589 batch_idx=batch_idx, 590 num_batches=num_batches, 591 ) 592 elif self.importance_type == "gradientshap_abs": 593 importance_step = self.get_importance_step_layer_gradientshap_abs( 594 layer_name=layer_name, 595 input=input, 596 baselines=None, 597 target=target, 598 batch_idx=batch_idx, 599 num_batches=num_batches, 600 ) 601 elif self.importance_type == "integrated_gradients_abs": 602 importance_step = ( 603 self.get_importance_step_layer_integrated_gradients_abs( 604 layer_name=layer_name, 605 input=input, 606 baselines=None, 607 target=target, 608 batch_idx=batch_idx, 609 num_batches=num_batches, 610 ) 611 ) 612 elif self.importance_type == "feature_ablation_abs": 613 importance_step = self.get_importance_step_layer_feature_ablation_abs( 614 layer_name=layer_name, 615 input=input, 616 layer_baselines=None, 617 target=target, 618 batch_idx=batch_idx, 619 num_batches=num_batches, 620 ) 621 elif self.importance_type == "lrp_abs": 622 importance_step = self.get_importance_step_layer_lrp_abs( 623 layer_name=layer_name, 624 input=input, 625 target=target, 626 batch_idx=batch_idx, 627 num_batches=num_batches, 628 ) 629 elif self.importance_type == "cbp_adaptation": 630 importance_step = self.get_importance_step_layer_weight_abs_sum( 631 layer_name=layer_name, 632 if_output_weight=False, 633 reciprocal=True, 634 ) 635 elif self.importance_type == "cbp_adaptive_contribution": 636 importance_step = ( 637 self.get_importance_step_layer_cbp_adaptive_contribution( 638 layer_name=layer_name, 639 activation=activation, 640 ) 641 ) 642 643 importance_step = min_max_normalize( 644 importance_step 645 ) # min-max scaling the utility to $[0, 1]$. See Eq. (5) in the paper 646 647 # multiply the importance by the training mask. See Eq. (6) in the paper 648 if self.step_multiply_training_mask: 649 importance_step = importance_step * mask[layer_name] 650 651 # update accumulated importance 652 self.importances[self.task_id][layer_name] = ( 653 self.importances[self.task_id][layer_name] + importance_step 654 ) 655 656 # update number of steps counter 657 self.num_steps_t += 1 658 659 def on_train_end(self) -> None: 660 r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task.""" 661 super().on_train_end() # store the mask and update cumulative and summative masks 662 663 for layer_name in self.backbone.weighted_layer_names: 664 665 # average the neuron-wise step importance. See Eq. 
(4) in the paper 666 self.importances[self.task_id][layer_name] = ( 667 self.importances[self.task_id][layer_name] 668 ) / self.num_steps_t 669 670 # add the base importance. See Eq. (6) in the paper 671 self.importances[self.task_id][layer_name] = ( 672 self.importances[self.task_id][layer_name] + self.base_importance 673 ) 674 675 # filter unmasked importance 676 if self.filter_unmasked_importance: 677 self.importances[self.task_id][layer_name] = ( 678 self.importances[self.task_id][layer_name] 679 * self.backbone.masks[f"{self.task_id}"][layer_name] 680 ) 681 682 # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper 683 if self.importance_summing_strategy == "add_latest": 684 self.summative_importance_for_previous_tasks[ 685 layer_name 686 ] += self.importances[self.task_id][layer_name] 687 688 elif self.importance_summing_strategy == "add_all": 689 for t in range(1, self.task_id + 1): 690 self.summative_importance_for_previous_tasks[ 691 layer_name 692 ] += self.importances[t][layer_name] 693 694 elif self.importance_summing_strategy == "add_average": 695 for t in range(1, self.task_id + 1): 696 self.summative_importance_for_previous_tasks[layer_name] += ( 697 self.importances[t][layer_name] / self.task_id 698 ) 699 else: 700 self.summative_importance_for_previous_tasks[ 701 layer_name 702 ] = torch.zeros_like( 703 self.summative_importance_for_previous_tasks[layer_name] 704 ).to( 705 self.device 706 ) # starting adding from 0 707 708 if self.importance_summing_strategy == "linear_decrease": 709 s = self.importance_summing_strategy_linear_step 710 for t in range(1, self.task_id + 1): 711 w_t = s * (self.task_id - t) + 1 712 713 elif self.importance_summing_strategy == "quadratic_decrease": 714 for t in range(1, self.task_id + 1): 715 w_t = (self.task_id - t + 1) ** 2 716 elif self.importance_summing_strategy == "cubic_decrease": 717 for t in range(1, self.task_id + 1): 718 w_t = (self.task_id - t + 1) ** 3 719 elif self.importance_summing_strategy == "exponential_decrease": 720 for t in range(1, self.task_id + 1): 721 r = self.importance_summing_strategy_exponential_rate 722 723 w_t = r ** (self.task_id - t + 1) 724 elif self.importance_summing_strategy == "log_decrease": 725 a = self.importance_summing_strategy_log_base 726 for t in range(1, self.task_id + 1): 727 w_t = math.log(self.task_id - t, a) + 1 728 elif self.importance_summing_strategy == "factorial_decrease": 729 for t in range(1, self.task_id + 1): 730 w_t = math.factorial(self.task_id - t + 1) 731 else: 732 raise ValueError 733 self.summative_importance_for_previous_tasks[layer_name] += ( 734 self.importances[t][layer_name] * w_t 735 ) 736 737 def get_importance_step_layer_weight_abs_sum( 738 self: str, 739 layer_name: str, 740 if_output_weight: bool, 741 reciprocal: bool, 742 ) -> Tensor: 743 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights. 744 745 **Args:** 746 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 747 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 748 - **reciprocal** (`bool`): whether to take reciprocal. 749 750 **Returns:** 751 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
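For illustration, a standalone sketch of the per-unit weight-magnitude sums computed here, showing the dimension logic for both a fully-connected layer and a convolutional layer (toy layers, not the `clarena` backbone):

```python
# Toy sketch: sum |W| over every dimension except dim 0 for this layer's weights
# (input weights), and except dim 1 for the next layer's weights (output weights).
import torch
from torch import nn

fc = nn.Linear(8, 16)       # weight shape (16, 8): dim 0 indexes this layer's units
next_fc = nn.Linear(16, 4)  # weight shape (4, 16): dim 1 indexes this layer's units

w_in = fc.weight.data.abs()
input_weight_abs_sum = w_in.sum(dim=[i for i in range(w_in.dim()) if i != 0])    # (16,)

w_out = next_fc.weight.data.abs()
output_weight_abs_sum = w_out.sum(dim=[i for i in range(w_out.dim()) if i != 1])  # (16,)

# same index logic for conv layers, whose weight shape is (C_out, C_in, kH, kW)
conv = nn.Conv2d(3, 16, kernel_size=3)
w_conv = conv.weight.data.abs()
per_channel_sum = w_conv.sum(dim=[i for i in range(w_conv.dim()) if i != 0])      # (16,)
```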
752 """ 753 layer = self.backbone.get_layer_by_name(layer_name) 754 755 if not if_output_weight: 756 weight_abs = torch.abs(layer.weight.data) 757 weight_abs_sum = torch.sum( 758 weight_abs, 759 dim=[ 760 i for i in range(weight_abs.dim()) if i != 0 761 ], # sum over the input dimension 762 ) 763 else: 764 weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 765 weight_abs_sum = torch.sum( 766 weight_abs, 767 dim=[ 768 i for i in range(weight_abs.dim()) if i != 1 769 ], # sum over the output dimension 770 ) 771 772 if reciprocal: 773 weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum) 774 importance_step_layer = weight_abs_sum_reciprocal 775 else: 776 importance_step_layer = weight_abs_sum 777 importance_step_layer = importance_step_layer.detach() 778 779 return importance_step_layer 780 781 def get_importance_step_layer_weight_gradient_abs_sum( 782 self: str, 783 layer_name: str, 784 if_output_weight: bool, 785 ) -> Tensor: 786 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights. 787 788 **Args:** 789 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 790 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 791 792 **Returns:** 793 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 794 """ 795 layer = self.backbone.get_layer_by_name(layer_name) 796 797 if not if_output_weight: 798 gradient_abs = torch.abs(layer.weight.grad.data) 799 gradient_abs_sum = torch.sum( 800 gradient_abs, 801 dim=[ 802 i for i in range(gradient_abs.dim()) if i != 0 803 ], # sum over the input dimension 804 ) 805 else: 806 gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data) 807 gradient_abs_sum = torch.sum( 808 gradient_abs, 809 dim=[ 810 i for i in range(gradient_abs.dim()) if i != 1 811 ], # sum over the output dimension 812 ) 813 814 importance_step_layer = gradient_abs_sum 815 importance_step_layer = importance_step_layer.detach() 816 817 return importance_step_layer 818 819 def get_importance_step_layer_activation_abs( 820 self: str, 821 activation: Tensor, 822 ) -> Tensor: 823 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum. 824 825 **Args:** 826 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 827 828 **Returns:** 829 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 830 """ 831 activation_abs_batch_mean = torch.mean( 832 torch.abs(activation), 833 dim=[ 834 i for i in range(activation.dim()) if i != 1 835 ], # average the features over batch samples 836 ) 837 importance_step_layer = activation_abs_batch_mean 838 importance_step_layer = importance_step_layer.detach() 839 840 return importance_step_layer 841 842 def get_importance_step_layer_weight_abs_sum_x_activation_abs( 843 self: str, 844 layer_name: str, 845 activation: Tensor, 846 if_output_weight: bool, 847 ) -> Tensor: 848 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. 
See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7). 849 850 **Args:** 851 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 852 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 853 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 854 855 **Returns:** 856 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 857 """ 858 layer = self.backbone.get_layer_by_name(layer_name) 859 860 if not if_output_weight: 861 weight_abs = torch.abs(layer.weight.data) 862 weight_abs_sum = torch.sum( 863 weight_abs, 864 dim=[ 865 i for i in range(weight_abs.dim()) if i != 0 866 ], # sum over the input dimension 867 ) 868 else: 869 weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 870 weight_abs_sum = torch.sum( 871 weight_abs, 872 dim=[ 873 i for i in range(weight_abs.dim()) if i != 1 874 ], # sum over the output dimension 875 ) 876 877 activation_abs_batch_mean = torch.mean( 878 torch.abs(activation), 879 dim=[ 880 i for i in range(activation.dim()) if i != 1 881 ], # average the features over batch samples 882 ) 883 884 importance_step_layer = weight_abs_sum * activation_abs_batch_mean 885 importance_step_layer = importance_step_layer.detach() 886 887 return importance_step_layer 888 889 def get_importance_step_layer_gradient_x_activation_abs( 890 self: str, 891 layer_name: str, 892 input: Tensor | tuple[Tensor, ...], 893 target: Tensor | None, 894 batch_idx: int, 895 num_batches: int, 896 ) -> Tensor: 897 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum. 898 899 **Args:** 900 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 901 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 902 - **target** (`Tensor` | `None`): the target batch of the training step. 903 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 904 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 905 906 **Returns:** 907 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
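For illustration only, a standalone sketch that computes the same gradient-times-activation quantity without Captum, using a forward hook and `torch.autograd.grad` on a toy model; it is equivalent in spirit to the `LayerGradientXActivation` call used here, not a drop-in replacement:

```python
# Toy sketch of |gradient x activation| per unit without Captum.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
layer = model[0]
x = torch.randn(32, 8)
y = torch.randint(0, 4, (32,))

captured = {}

def save_activation(module, inputs, output):
    captured["act"] = output  # keep the activation attached to the autograd graph

handle = layer.register_forward_hook(save_activation)
logits = model(x)
handle.remove()

activation = captured["act"]                    # (32, 16)
score = logits.gather(1, y.view(-1, 1)).sum()   # sum of each sample's target logit
(grad,) = torch.autograd.grad(score, activation)

importance = (grad * activation).abs().mean(dim=0).detach()  # one value per unit
```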
908 """ 909 layer = self.backbone.get_layer_by_name(layer_name) 910 911 input = input.requires_grad_() 912 913 # initialize the Layer Gradient X Activation object 914 layer_gradient_x_activation = LayerGradientXActivation( 915 forward_func=self.forward, layer=layer 916 ) 917 918 self.set_forward_func_return_logits_only(True) 919 # calculate layer attribution of the step 920 attribution = layer_gradient_x_activation.attribute( 921 inputs=input, 922 target=target, 923 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 924 ) 925 self.set_forward_func_return_logits_only(False) 926 927 attribution_abs_batch_mean = torch.mean( 928 torch.abs(attribution), 929 dim=[ 930 i for i in range(attribution.dim()) if i != 1 931 ], # average the features over batch samples 932 ) 933 934 importance_step_layer = attribution_abs_batch_mean 935 importance_step_layer = importance_step_layer.detach() 936 937 return importance_step_layer 938 939 def get_importance_step_layer_weight_gradient_square_sum( 940 self: str, 941 layer_name: str, 942 activation: Tensor, 943 if_output_weight: bool, 944 ) -> Tensor: 945 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114). 946 947 **Args:** 948 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 949 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 950 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 951 952 **Returns:** 953 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 954 """ 955 layer = self.backbone.get_layer_by_name(layer_name) 956 957 if not if_output_weight: 958 gradient_square = layer.weight.grad.data**2 959 gradient_square_sum = torch.sum( 960 gradient_square, 961 dim=[ 962 i for i in range(gradient_square.dim()) if i != 0 963 ], # sum over the input dimension 964 ) 965 else: 966 gradient_square = self.next_layer(layer_name).weight.grad.data**2 967 gradient_square_sum = torch.sum( 968 gradient_square, 969 dim=[ 970 i for i in range(gradient_square.dim()) if i != 1 971 ], # sum over the output dimension 972 ) 973 974 importance_step_layer = gradient_square_sum 975 importance_step_layer = importance_step_layer.detach() 976 977 return importance_step_layer 978 979 def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 980 self: str, 981 layer_name: str, 982 activation: Tensor, 983 if_output_weight: bool, 984 ) -> Tensor: 985 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114). 986 987 **Args:** 988 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 989 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 990 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 
991 992 **Returns:** 993 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 994 """ 995 layer = self.backbone.get_layer_by_name(layer_name) 996 997 if not if_output_weight: 998 gradient_square = layer.weight.grad.data**2 999 gradient_square_sum = torch.sum( 1000 gradient_square, 1001 dim=[ 1002 i for i in range(gradient_square.dim()) if i != 0 1003 ], # sum over the input dimension 1004 ) 1005 else: 1006 gradient_square = self.next_layer(layer_name).weight.grad.data**2 1007 gradient_square_sum = torch.sum( 1008 gradient_square, 1009 dim=[ 1010 i for i in range(gradient_square.dim()) if i != 1 1011 ], # sum over the output dimension 1012 ) 1013 1014 activation_abs_batch_mean = torch.mean( 1015 torch.abs(activation), 1016 dim=[ 1017 i for i in range(activation.dim()) if i != 1 1018 ], # average the features over batch samples 1019 ) 1020 1021 importance_step_layer = gradient_square_sum * activation_abs_batch_mean 1022 importance_step_layer = importance_step_layer.detach() 1023 1024 return importance_step_layer 1025 1026 def get_importance_step_layer_conductance_abs( 1027 self, 1028 layer_name: str, 1029 input: Tensor | tuple[Tensor, ...], 1030 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1031 target: Tensor | None, 1032 batch_idx: int, 1033 num_batches: int, 1034 ) -> Tensor: 1035 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum. 1036 1037 **Args:** 1038 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1039 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1040 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details. 1041 - **target** (`Tensor` | `None`): the target batch of the training step. 1042 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1043 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1044 1045 **Returns:** 1046 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1047 """ 1048 layer = self.backbone.get_layer_by_name(layer_name) 1049 1050 # initialize the Layer Conductance object 1051 layer_conductance = LayerConductance(forward_func=self.forward, layer=layer) 1052 1053 self.set_forward_func_return_logits_only(True) 1054 # calculate layer attribution of the step 1055 attribution = layer_conductance.attribute( 1056 inputs=input, 1057 baselines=baselines, 1058 target=target, 1059 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1060 ) 1061 self.set_forward_func_return_logits_only(False) 1062 1063 attribution_abs_batch_mean = torch.mean( 1064 torch.abs(attribution), 1065 dim=[ 1066 i for i in range(attribution.dim()) if i != 1 1067 ], # average the features over batch samples 1068 ) 1069 1070 importance_step_layer = attribution_abs_batch_mean 1071 importance_step_layer = importance_step_layer.detach() 1072 1073 return importance_step_layer 1074 1075 def get_importance_step_layer_internal_influence_abs( 1076 self: str, 1077 layer_name: str, 1078 input: Tensor | tuple[Tensor, ...], 1079 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1080 target: Tensor | None, 1081 batch_idx: int, 1082 num_batches: int, 1083 ) -> Tensor: 1084 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum. 1085 1086 **Args:** 1087 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1088 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1089 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details. 1090 - **target** (`Tensor` | `None`): the target batch of the training step. 1091 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1092 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1093 1094 **Returns:** 1095 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1096 """ 1097 layer = self.backbone.get_layer_by_name(layer_name) 1098 1099 # initialize the Internal Influence object 1100 internal_influence = InternalInfluence(forward_func=self.forward, layer=layer) 1101 1102 # convert the target to long type to avoid error 1103 target = target.long() if target is not None else None 1104 1105 self.set_forward_func_return_logits_only(True) 1106 # calculate layer attribution of the step 1107 attribution = internal_influence.attribute( 1108 inputs=input, 1109 baselines=baselines, 1110 target=target, 1111 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1112 n_steps=5, # set 10 instead of default 50 to accelerate the computation 1113 ) 1114 self.set_forward_func_return_logits_only(False) 1115 1116 attribution_abs_batch_mean = torch.mean( 1117 torch.abs(attribution), 1118 dim=[ 1119 i for i in range(attribution.dim()) if i != 1 1120 ], # average the features over batch samples 1121 ) 1122 1123 importance_step_layer = attribution_abs_batch_mean 1124 importance_step_layer = importance_step_layer.detach() 1125 1126 return importance_step_layer 1127 1128 def get_importance_step_layer_gradcam_abs( 1129 self: str, 1130 layer_name: str, 1131 input: Tensor | tuple[Tensor, ...], 1132 target: Tensor | None, 1133 batch_idx: int, 1134 num_batches: int, 1135 ) -> Tensor: 1136 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum. 1137 1138 **Args:** 1139 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1140 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1141 - **target** (`Tensor` | `None`): the target batch of the training step. 1142 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1143 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1144 1145 **Returns:** 1146 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1147 """ 1148 layer = self.backbone.get_layer_by_name(layer_name) 1149 1150 # initialize the GradCAM object 1151 gradcam = LayerGradCam(forward_func=self.forward, layer=layer) 1152 1153 self.set_forward_func_return_logits_only(True) 1154 # calculate layer attribution of the step 1155 attribution = gradcam.attribute( 1156 inputs=input, 1157 target=target, 1158 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1159 ) 1160 self.set_forward_func_return_logits_only(False) 1161 1162 attribution_abs_batch_mean = torch.mean( 1163 torch.abs(attribution), 1164 dim=[ 1165 i for i in range(attribution.dim()) if i != 1 1166 ], # average the features over batch samples 1167 ) 1168 1169 importance_step_layer = attribution_abs_batch_mean 1170 importance_step_layer = importance_step_layer.detach() 1171 1172 return importance_step_layer 1173 1174 def get_importance_step_layer_deeplift_abs( 1175 self: str, 1176 layer_name: str, 1177 input: Tensor | tuple[Tensor, ...], 1178 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1179 target: Tensor | None, 1180 batch_idx: int, 1181 num_batches: int, 1182 ) -> Tensor: 1183 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum. 1184 1185 **Args:** 1186 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1187 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1188 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details. 1189 - **target** (`Tensor` | `None`): the target batch of the training step. 1190 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1191 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1192 1193 **Returns:** 1194 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1195 """ 1196 layer = self.backbone.get_layer_by_name(layer_name) 1197 1198 # initialize the Layer DeepLift object 1199 layer_deeplift = LayerDeepLift(model=self, layer=layer) 1200 1201 # convert the target to long type to avoid error 1202 target = target.long() if target is not None else None 1203 1204 self.set_forward_func_return_logits_only(True) 1205 # calculate layer attribution of the step 1206 attribution = layer_deeplift.attribute( 1207 inputs=input, 1208 baselines=baselines, 1209 target=target, 1210 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1211 ) 1212 self.set_forward_func_return_logits_only(False) 1213 1214 attribution_abs_batch_mean = torch.mean( 1215 torch.abs(attribution), 1216 dim=[ 1217 i for i in range(attribution.dim()) if i != 1 1218 ], # average the features over batch samples 1219 ) 1220 1221 importance_step_layer = attribution_abs_batch_mean 1222 importance_step_layer = importance_step_layer.detach() 1223 1224 return importance_step_layer 1225 1226 def get_importance_step_layer_deepliftshap_abs( 1227 self: str, 1228 layer_name: str, 1229 input: Tensor | tuple[Tensor, ...], 1230 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1231 target: Tensor | None, 1232 batch_idx: int, 1233 num_batches: int, 1234 ) -> Tensor: 1235 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum. 1236 1237 **Args:** 1238 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1239 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1240 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details. 1241 - **target** (`Tensor` | `None`): the target batch of the training step. 1242 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1243 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1244 1245 **Returns:** 1246 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1247 """ 1248 layer = self.backbone.get_layer_by_name(layer_name) 1249 1250 # initialize the Layer DeepLiftShap object 1251 layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer) 1252 1253 # convert the target to long type to avoid error 1254 target = target.long() if target is not None else None 1255 1256 self.set_forward_func_return_logits_only(True) 1257 # calculate layer attribution of the step 1258 attribution = layer_deepliftshap.attribute( 1259 inputs=input, 1260 baselines=baselines, 1261 target=target, 1262 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1263 ) 1264 self.set_forward_func_return_logits_only(False) 1265 1266 attribution_abs_batch_mean = torch.mean( 1267 torch.abs(attribution), 1268 dim=[ 1269 i for i in range(attribution.dim()) if i != 1 1270 ], # average the features over batch samples 1271 ) 1272 1273 importance_step_layer = attribution_abs_batch_mean 1274 importance_step_layer = importance_step_layer.detach() 1275 1276 return importance_step_layer 1277 1278 def get_importance_step_layer_gradientshap_abs( 1279 self: str, 1280 layer_name: str, 1281 input: Tensor | tuple[Tensor, ...], 1282 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1283 target: Tensor | None, 1284 batch_idx: int, 1285 num_batches: int, 1286 ) -> Tensor: 1287 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum. 1288 1289 **Args:** 1290 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1291 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1292 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero. 1293 - **target** (`Tensor` | `None`): the target batch of the training step. 1294 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1295 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1296 1297 **Returns:** 1298 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1299 """ 1300 layer = self.backbone.get_layer_by_name(layer_name) 1301 1302 if baselines is None: 1303 baselines = torch.zeros_like( 1304 input 1305 ) # baselines are mandatory for GradientShap API. 
We explicitly set them to zero 1306 1307 # initialize the Layer GradientShap object 1308 layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer) 1309 1310 # convert the target to long type to avoid error 1311 target = target.long() if target is not None else None 1312 1313 self.set_forward_func_return_logits_only(True) 1314 # calculate layer attribution of the step 1315 attribution = layer_gradientshap.attribute( 1316 inputs=input, 1317 baselines=baselines, 1318 target=target, 1319 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1320 ) 1321 self.set_forward_func_return_logits_only(False) 1322 1323 attribution_abs_batch_mean = torch.mean( 1324 torch.abs(attribution), 1325 dim=[ 1326 i for i in range(attribution.dim()) if i != 1 1327 ], # average the features over batch samples 1328 ) 1329 1330 importance_step_layer = attribution_abs_batch_mean 1331 importance_step_layer = importance_step_layer.detach() 1332 1333 return importance_step_layer 1334 1335 def get_importance_step_layer_integrated_gradients_abs( 1336 self: str, 1337 layer_name: str, 1338 input: Tensor | tuple[Tensor, ...], 1339 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1340 target: Tensor | None, 1341 batch_idx: int, 1342 num_batches: int, 1343 ) -> Tensor: 1344 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum. 1345 1346 **Args:** 1347 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1348 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1349 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details. 1350 - **target** (`Tensor` | `None`): the target batch of the training step. 1351 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1352 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1353 1354 **Returns:** 1355 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1356 """ 1357 layer = self.backbone.get_layer_by_name(layer_name) 1358 1359 # initialize the Layer Integrated Gradients object 1360 layer_integrated_gradients = LayerIntegratedGradients( 1361 forward_func=self.forward, layer=layer 1362 ) 1363 1364 self.set_forward_func_return_logits_only(True) 1365 # calculate layer attribution of the step 1366 attribution = layer_integrated_gradients.attribute( 1367 inputs=input, 1368 baselines=baselines, 1369 target=target, 1370 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1371 ) 1372 self.set_forward_func_return_logits_only(False) 1373 1374 attribution_abs_batch_mean = torch.mean( 1375 torch.abs(attribution), 1376 dim=[ 1377 i for i in range(attribution.dim()) if i != 1 1378 ], # average the features over batch samples 1379 ) 1380 1381 importance_step_layer = attribution_abs_batch_mean 1382 importance_step_layer = importance_step_layer.detach() 1383 1384 return importance_step_layer 1385 1386 def get_importance_step_layer_feature_ablation_abs( 1387 self: str, 1388 layer_name: str, 1389 input: Tensor | tuple[Tensor, ...], 1390 layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1391 target: Tensor | None, 1392 batch_idx: int, 1393 num_batches: int, 1394 if_captum: bool = False, 1395 ) -> Tensor: 1396 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum. 1397 1398 **Args:** 1399 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1400 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1401 - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details. 1402 - **target** (`Tensor` | `None`): the target batch of the training step. 1403 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1404 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1405 - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster. 1406 1407 **Returns:** 1408 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1409 """ 1410 layer = self.backbone.get_layer_by_name(layer_name) 1411 1412 if not if_captum: 1413 # 1. Baseline logits (take first element of forward output) 1414 baseline_out, _, _ = self.forward( 1415 input, "train", batch_idx, num_batches, self.task_id 1416 ) 1417 if target is not None: 1418 baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1) 1419 else: 1420 baseline_scores = baseline_out.sum(dim=1) 1421 1422 # 2. 
Capture layer’s output shape 1423 activs = {} 1424 handle = layer.register_forward_hook( 1425 lambda module, inp, out: activs.setdefault("output", out.detach()) 1426 ) 1427 _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id) 1428 handle.remove() 1429 layer_output = activs["output"] # shape (B, F, ...) 1430 1431 # 3. Build baseline tensor matching that shape 1432 if layer_baselines is None: 1433 baseline_tensor = torch.zeros_like(layer_output) 1434 elif isinstance(layer_baselines, (int, float)): 1435 baseline_tensor = torch.full_like(layer_output, layer_baselines) 1436 elif isinstance(layer_baselines, Tensor): 1437 if layer_baselines.shape == layer_output.shape: 1438 baseline_tensor = layer_baselines 1439 elif layer_baselines.shape == layer_output.shape[1:]: 1440 baseline_tensor = layer_baselines.unsqueeze(0).repeat( 1441 layer_output.size(0), *([1] * layer_baselines.ndim) 1442 ) 1443 else: 1444 raise ValueError(...) 1445 else: 1446 raise ValueError(...) 1447 1448 B, F = layer_output.size(0), layer_output.size(1) 1449 1450 # 4. Create a “mega-batch” replicating the input F times 1451 if isinstance(input, tuple): 1452 mega_inputs = tuple( 1453 t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:]) 1454 for t in input 1455 ) 1456 else: 1457 mega_inputs = ( 1458 input.unsqueeze(0) 1459 .repeat(F, *([1] * input.ndim)) 1460 .view(-1, *input.shape[1:]) 1461 ) 1462 1463 # 5. Equally replicate the baseline tensor 1464 mega_baseline = ( 1465 baseline_tensor.unsqueeze(0) 1466 .repeat(F, *([1] * baseline_tensor.ndim)) 1467 .view(-1, *baseline_tensor.shape[1:]) 1468 ) 1469 1470 # 6. Precompute vectorized indices 1471 device = layer_output.device 1472 positions = torch.arange(F * B, device=device) # [0,1,...,F*B-1] 1473 feat_idx = torch.arange(F, device=device).repeat_interleave( 1474 B 1475 ) # [0,0,...,1,1,...,F-1] 1476 1477 # 7. One hook to zero out each channel slice across the mega-batch 1478 def mega_ablate_hook(module, inp, out): 1479 out_mod = out.clone() 1480 # for each sample in mega-batch, zero its corresponding channel 1481 out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx] 1482 return out_mod 1483 1484 h = layer.register_forward_hook(mega_ablate_hook) 1485 out_all, _, _ = self.forward( 1486 mega_inputs, "train", batch_idx, num_batches, self.task_id 1487 ) 1488 h.remove() 1489 1490 # 8. 
Recover scores, reshape [F*B] → [F, B], diff & mean 1491 if target is not None: 1492 tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1) 1493 scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1) 1494 else: 1495 scores_all = out_all.sum(dim=1) 1496 1497 scores_all = scores_all.view(F, B) 1498 diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all) 1499 importance_step_layer = diffs.mean(dim=1).detach() # [F] 1500 1501 return importance_step_layer 1502 1503 else: 1504 # initialize the Layer Feature Ablation object 1505 layer_feature_ablation = LayerFeatureAblation( 1506 forward_func=self.forward, layer=layer 1507 ) 1508 1509 # calculate layer attribution of the step 1510 self.set_forward_func_return_logits_only(True) 1511 attribution = layer_feature_ablation.attribute( 1512 inputs=input, 1513 layer_baselines=layer_baselines, 1514 # target=target, # disable target to enable perturbations_per_eval 1515 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1516 perturbations_per_eval=128, # to accelerate the computation 1517 ) 1518 self.set_forward_func_return_logits_only(False) 1519 1520 attribution_abs_batch_mean = torch.mean( 1521 torch.abs(attribution), 1522 dim=[ 1523 i for i in range(attribution.dim()) if i != 1 1524 ], # average the features over batch samples 1525 ) 1526 1527 importance_step_layer = attribution_abs_batch_mean 1528 importance_step_layer = importance_step_layer.detach() 1529 1530 return importance_step_layer 1531 1532 def get_importance_step_layer_lrp_abs( 1533 self: str, 1534 layer_name: str, 1535 input: Tensor | tuple[Tensor, ...], 1536 target: Tensor | None, 1537 batch_idx: int, 1538 num_batches: int, 1539 ) -> Tensor: 1540 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum. 1541 1542 **Args:** 1543 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1544 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1545 - **target** (`Tensor` | `None`): the target batch of the training step. 1546 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1547 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1548 1549 **Returns:** 1550 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 
1551 """ 1552 layer = self.backbone.get_layer_by_name(layer_name) 1553 1554 # initialize the Layer LRP object 1555 layer_lrp = LayerLRP(model=self, layer=layer) 1556 1557 # set model to evaluation mode to prevent updating the model parameters 1558 self.eval() 1559 1560 self.set_forward_func_return_logits_only(True) 1561 # calculate layer attribution of the step 1562 attribution = layer_lrp.attribute( 1563 inputs=input, 1564 target=target, 1565 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1566 ) 1567 self.set_forward_func_return_logits_only(False) 1568 1569 attribution_abs_batch_mean = torch.mean( 1570 torch.abs(attribution), 1571 dim=[ 1572 i for i in range(attribution.dim()) if i != 1 1573 ], # average the features over batch samples 1574 ) 1575 1576 importance_step_layer = attribution_abs_batch_mean 1577 importance_step_layer = importance_step_layer.detach() 1578 1579 return importance_step_layer 1580 1581 def get_importance_step_layer_cbp_adaptive_contribution( 1582 self: str, 1583 layer_name: str, 1584 activation: Tensor, 1585 ) -> Tensor: 1586 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the reciprocal of sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7). 1587 1588 **Args:** 1589 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1590 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 1591 1592 **Returns:** 1593 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1594 """ 1595 layer = self.backbone.get_layer_by_name(layer_name) 1596 1597 input_weight_abs = torch.abs(layer.weight.data) 1598 input_weight_abs_sum = torch.sum( 1599 input_weight_abs, 1600 dim=[ 1601 i for i in range(input_weight_abs.dim()) if i != 0 1602 ], # sum over the input dimension 1603 ) 1604 input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum) 1605 1606 output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 1607 output_weight_abs_sum = torch.sum( 1608 output_weight_abs, 1609 dim=[ 1610 i for i in range(output_weight_abs.dim()) if i != 1 1611 ], # sum over the output dimension 1612 ) 1613 1614 activation_abs_batch_mean = torch.mean( 1615 torch.abs(activation), 1616 dim=[ 1617 i for i in range(activation.dim()) if i != 1 1618 ], # average the features over batch samples 1619 ) 1620 1621 importance_step_layer = ( 1622 output_weight_abs_sum 1623 * activation_abs_batch_mean 1624 * input_weight_abs_sum_reciprocal 1625 ) 1626 importance_step_layer = importance_step_layer.detach() 1627 1628 return importance_step_layer
FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm.
An architecture-based continual learning approach that improves AdaHAT (Adaptive Hard Attention to the Task) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components.
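As an API illustration only, a minimal construction sketch follows. It assumes that a backbone (an instance of a HATMaskBackbone subclass) and a heads object (a HeadsTIL instance) have already been built elsewhere, and the hyperparameter values are placeholders rather than recommended settings.

from clarena.cl_algorithms.fgadahat import FGAdaHAT

model = FGAdaHAT(
    backbone=backbone,   # HATMaskBackbone subclass instance, built elsewhere
    heads=heads,         # HeadsTIL instance, built elsewhere
    adjustment_intensity=0.1,                                  # placeholder value
    importance_type="input_weight_abs_sum_x_activation_abs",   # ICU in the paper
    importance_summing_strategy="add_all",
    importance_scheduler_type="linear_sparsity_reg",
    neuron_to_weight_importance_aggregation_mode="min",
    s_max=400.0,                                               # placeholder HAT hyperparameters
    clamp_threshold=6000.0,
    mask_sparsity_reg_factor=0.75,
)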
45 def __init__( 46 self, 47 backbone: HATMaskBackbone, 48 heads: HeadsTIL, 49 adjustment_intensity: float, 50 importance_type: str, 51 importance_summing_strategy: str, 52 importance_scheduler_type: str, 53 neuron_to_weight_importance_aggregation_mode: str, 54 s_max: float, 55 clamp_threshold: float, 56 mask_sparsity_reg_factor: float, 57 mask_sparsity_reg_mode: str = "original", 58 base_importance: float = 0.01, 59 base_mask_sparsity_reg: float = 0.1, 60 base_linear: float = 10, 61 filter_by_cumulative_mask: bool = False, 62 filter_unmasked_importance: bool = True, 63 step_multiply_training_mask: bool = True, 64 task_embedding_init_mode: str = "N01", 65 importance_summing_strategy_linear_step: float | None = None, 66 importance_summing_strategy_exponential_rate: float | None = None, 67 importance_summing_strategy_log_base: float | None = None, 68 non_algorithmic_hparams: dict[str, Any] = {}, 69 ) -> None: 70 r"""Initialize the FG-AdaHAT algorithm with the network. 71 72 **Args:** 73 - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism. 74 - **heads** (`HeadsTIL`): output heads. FG-AdaHAT supports only TIL (Task-Incremental Learning). 75 - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper). 76 - **importance_type** (`str`): the type of neuron-wise importance, must be one of: 77 1. 'input_weight_abs_sum': sum of absolute input weights; 78 2. 'output_weight_abs_sum': sum of absolute output weights; 79 3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper); 80 4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper); 81 5. 'activation_abs': absolute activation; 82 6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper); 83 7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper); 84 8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation; 85 9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights; 86 10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights; 87 11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper); 88 12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation; 89 13. 'conductance_abs': absolute layer conductance; 90 14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper); 91 15. 'gradcam_abs': absolute Grad-CAM; 92 16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper); 93 17. 'deepliftshap_abs': absolute DeepLIFT-SHAP; 94 18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper); 95 19. 'integrated_gradients_abs': absolute Integrated Gradients; 96 20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper); 97 21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP); 98 22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 99 23. 
'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7); 100 - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of: 101 1. 'add_latest': add the latest neuron-wise importance to the summative importance; 102 2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance; 103 3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance; 104 4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID; 105 5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID; 106 6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID; 107 7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID; 108 8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID; 109 9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID; 110 - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of: 111 1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization betwwen the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`); 112 2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$; 113 3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$. 114 - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of: 115 1. 'min': take the minimum of neuron-wise importance for each weight; 116 2. 'max': take the maximum of neuron-wise importance for each weight; 117 3. 'mean': take the mean of neuron-wise importance for each weight. 118 - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 119 - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 120 - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity. 121 - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of: 122 1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a). 123 2. 'cross': the cross version of mask sparsity regularization. 124 - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01. 125 - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1. 
126 - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10. 127 - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False. 128 - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set to 0) at the end of task training. Default: False. 129 - **step_multiply_training_mask** (`bool`): whether to multiply the training mask to the importance at each training step. Default: True. 130 - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of: 131 1. 'N01' (default): standard normal distribution $N(0, 1)$. 132 2. 'U-11': uniform distribution $U(-1, 1)$. 133 3. 'U01': uniform distribution $U(0, 1)$. 134 4. 'U-10': uniform distribution $U(-1, 0)$. 135 5. 'last': inherit the task embedding from the last task. 136 - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0. 137 - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1. 138 - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1. 139 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 140 141 """ 142 super().__init__( 143 backbone=backbone, 144 heads=heads, 145 adjustment_mode=None, # use the own adjustment mechanism of FG-AdaHAT 146 adjustment_intensity=adjustment_intensity, 147 s_max=s_max, 148 clamp_threshold=clamp_threshold, 149 mask_sparsity_reg_factor=mask_sparsity_reg_factor, 150 mask_sparsity_reg_mode=mask_sparsity_reg_mode, 151 task_embedding_init_mode=task_embedding_init_mode, 152 epsilon=base_mask_sparsity_reg, # the epsilon is now the base mask sparsity regularization factor 153 non_algorithmic_hparams=non_algorithmic_hparams, 154 ) 155 156 # save additional algorithmic hyperparameters 157 self.save_hyperparameters( 158 "adjustment_intensity", 159 "importance_type", 160 "importance_summing_strategy", 161 "importance_scheduler_type", 162 "neuron_to_weight_importance_aggregation_mode", 163 "s_max", 164 "clamp_threshold", 165 "mask_sparsity_reg_factor", 166 "mask_sparsity_reg_mode", 167 "base_importance", 168 "base_mask_sparsity_reg", 169 "base_linear", 170 "filter_by_cumulative_mask", 171 "filter_unmasked_importance", 172 "step_multiply_training_mask", 173 ) 174 175 self.importance_type: str | None = importance_type 176 r"""The type of the neuron-wise importance added to AdaHAT importance.""" 177 178 self.importance_scheduler_type: str = importance_scheduler_type 179 r"""The type of the importance scheduler.""" 180 self.neuron_to_weight_importance_aggregation_mode: str = ( 181 neuron_to_weight_importance_aggregation_mode 182 ) 183 r"""The mode of aggregation from neuron-wise to weight-wise importance. 
""" 184 self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask 185 r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate.""" 186 self.filter_unmasked_importance: bool = filter_unmasked_importance 187 r"""The flag to filter unmasked importance values (set them to 0) at the end of task training.""" 188 self.step_multiply_training_mask: bool = step_multiply_training_mask 189 r"""The flag to multiply the training mask to the importance at each training step.""" 190 191 # importance summing strategy 192 self.importance_summing_strategy: str = importance_summing_strategy 193 r"""The strategy to sum the neuron-wise importance for previous tasks.""" 194 if importance_summing_strategy_linear_step is not None: 195 self.importance_summing_strategy_linear_step: float = ( 196 importance_summing_strategy_linear_step 197 ) 198 r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease').""" 199 if importance_summing_strategy_exponential_rate is not None: 200 self.importance_summing_strategy_exponential_rate: float = ( 201 importance_summing_strategy_exponential_rate 202 ) 203 r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """ 204 if importance_summing_strategy_log_base is not None: 205 self.importance_summing_strategy_log_base: float = ( 206 importance_summing_strategy_log_base 207 ) 208 r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """ 209 210 # base values 211 self.base_importance: float = base_importance 212 r"""The base value added to the importance to avoid zero. """ 213 self.base_mask_sparsity_reg: float = base_mask_sparsity_reg 214 r"""The base value added to the mask sparsity regularization to avoid zero. """ 215 self.base_linear: float = base_linear 216 r"""The base value added to the linear layer to avoid zero. """ 217 218 self.importances: dict[int, dict[str, Tensor]] = {} 219 r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """ 220 self.summative_importance_for_previous_tasks: dict[str, Tensor] = {} 221 r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """ 222 223 self.num_steps_t: int 224 r"""The number of training steps for the current task `self.task_id`.""" 225 # set manual optimization 226 self.automatic_optimization = False 227 228 FGAdaHAT.sanity_check(self)
Initialize the FG-AdaHAT algorithm with the network.

Args:
- backbone (HATMaskBackbone): must be a backbone network with the HAT mask mechanism.
- heads (HeadsTIL): output heads. FG-AdaHAT supports only TIL (Task-Incremental Learning).
- adjustment_intensity (float): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
- importance_type (str): the type of neuron-wise importance, must be one of:
  - 'input_weight_abs_sum': sum of absolute input weights;
  - 'output_weight_abs_sum': sum of absolute output weights;
  - 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
  - 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
  - 'activation_abs': absolute activation;
  - 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
  - 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
  - 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
  - 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
  - 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
  - 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
  - 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
  - 'conductance_abs': absolute layer conductance;
  - 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
  - 'gradcam_abs': absolute Grad-CAM;
  - 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
  - 'deepliftshap_abs': absolute DeepLIFT-SHAP;
  - 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
  - 'integrated_gradients_abs': absolute Integrated Gradients;
  - 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
  - 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
  - 'cbp_adaptation': the adaptation function in Continual Backpropagation (CBP);
  - 'cbp_adaptive_contribution': the adaptive contribution function in Continual Backpropagation (CBP).
- importance_summing_strategy (str): the strategy to sum neuron-wise importance for previous tasks, must be one of:
  - 'add_latest': add the latest neuron-wise importance to the summative importance;
  - 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
  - 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
  - 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
  - 'quadratic_decrease': weigh the previous neuron-wise importance by a factor that decreases quadratically with the task ID;
  - 'cubic_decrease': weigh the previous neuron-wise importance by a factor that decreases cubically with the task ID;
  - 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
  - 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
  - 'factorial_decrease': weigh the previous neuron-wise importance by a factor that decreases factorially with the task ID.
- importance_scheduler_type (str): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of:
  - 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument base_linear), and $b_R$ is the base mask sparsity regularization factor (see argument base_mask_sparsity_reg);
  - 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
  - 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
- neuron_to_weight_importance_aggregation_mode (str): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
  - 'min': take the minimum of neuron-wise importance for each weight;
  - 'max': take the maximum of neuron-wise importance for each weight;
  - 'mean': take the mean of neuron-wise importance for each weight.
- s_max (float): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
- clamp_threshold (float): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
- mask_sparsity_reg_factor (float): hyperparameter, the regularization factor for mask sparsity.
- mask_sparsity_reg_mode (str): the mode of mask sparsity regularization, must be one of:
  - 'original' (default): the original mask sparsity regularization in the HAT paper;
  - 'cross': the cross version of mask sparsity regularization.
- base_importance (float): base value added to importance ($b_I$ in the paper). Default: 0.01.
- base_mask_sparsity_reg (float): base value added to the mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
- base_linear (float): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
- filter_by_cumulative_mask (bool): whether to multiply the cumulative mask with the importance when calculating the adjustment rate. Default: False.
- filter_unmasked_importance (bool): whether to filter unmasked importance values (set them to 0) at the end of task training. Default: True.
- step_multiply_training_mask (bool): whether to multiply the training mask with the importance at each training step. Default: True.
- task_embedding_init_mode (str): the initialization mode for task embeddings, must be one of:
  - 'N01' (default): standard normal distribution $N(0, 1)$;
  - 'U-11': uniform distribution $U(-1, 1)$;
  - 'U01': uniform distribution $U(0, 1)$;
  - 'U-10': uniform distribution $U(-1, 0)$;
  - 'last': inherit the task embedding from the last task.
- importance_summing_strategy_linear_step (float | None): linear step for the importance summing strategy (used when importance_summing_strategy is 'linear_decrease'). Must be > 0.
- importance_summing_strategy_exponential_rate (float | None): exponential rate for the importance summing strategy (used when importance_summing_strategy is 'exponential_decrease'). Must be > 1.
- importance_summing_strategy_log_base (float | None): base of the logarithm in the importance summing strategy (used when importance_summing_strategy is 'log_decrease'). Must be > 1.
- non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters (e.g., optimizer and learning rate scheduler configurations) passed to this LightningModule from the config. They are saved for Lightning APIs via the save_hyperparameters() method, which supports experiment configuration and reproducibility.
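To make the scheduler formulas above concrete, here is a minimal sketch in plain Python (not part of the module). The scalar sparsity_reg stands in for $R(M^t, M^{<t})$; only the first two documented scheduler variants are shown, since 'summative_mask_sparsity_reg' is element-wise over the summative masks.

def importance_scheduler(
    t: int,
    sparsity_reg: float,
    base_linear: float = 10.0,            # b_L, documented default
    base_mask_sparsity_reg: float = 0.1,  # b_R, documented default
    scheduler_type: str = "linear_sparsity_reg",
) -> float:
    if scheduler_type == "linear_sparsity_reg":
        # c^t = (t + b_L) * [R + b_R]
        return (t + base_linear) * (sparsity_reg + base_mask_sparsity_reg)
    if scheduler_type == "sparsity_reg":
        # c^t = R + b_R
        return sparsity_reg + base_mask_sparsity_reg
    raise ValueError(f"unknown scheduler_type: {scheduler_type}")

print(importance_scheduler(t=3, sparsity_reg=0.2))  # (3 + 10) * (0.2 + 0.1) ≈ 3.9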
The mode of aggregation from neuron-wise to weight-wise importance.
The flag to filter importance by the cumulative mask when calculating the adjustment rate.
The flag to filter unmasked importance values (set them to 0) at the end of task training.
The flag to multiply the training mask to the importance at each training step.
The base value added to the mask sparsity regularization to avoid zero.
The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance dicts. Each importance dict maps layer names to the importance tensor for that layer. Each importance tensor has the same size as the feature tensor, (number of units, ).
The summative neuron-wise importance values of units for previous tasks before the current task self.task_id. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor, (number of units, ).
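The two bookkeeping structures above can be pictured with a toy sketch (hypothetical layer names and unit counts), including an 'add_latest'-style update of the summative importance.

import torch

importances = {
    1: {"fc1": torch.zeros(64), "fc2": torch.zeros(32)},  # task 1, min-max scaled to [0, 1]
    2: {"fc1": torch.zeros(64), "fc2": torch.zeros(32)},  # task 2
}
summative_importance_for_previous_tasks = {
    "fc1": torch.zeros(64),  # I^{<t} for layer "fc1"
    "fc2": torch.zeros(32),
}

# an 'add_latest'-style update after finishing task 2:
for layer_name, importance in importances[2].items():
    summative_importance_for_previous_tasks[layer_name] += importance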
290 @property 291 def automatic_optimization(self) -> bool: 292 """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``.""" 293 return self._automatic_optimization
If set to False you are responsible for calling .backward(), .step(), .zero_grad().
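For readers unfamiliar with Lightning's manual optimization, a minimal, generic LightningModule sketch follows. It is independent of clarena and assumes the lightning package namespace (the project may import pytorch_lightning instead). It only illustrates the contract this property refers to: with automatic optimization off, the module itself calls zero_grad(), manual_backward(), and step(), which is what lets FG-AdaHAT adjust gradients between the backward pass and the optimizer step.

import torch
from torch import nn
import lightning as L  # assumption: `lightning` namespace; `pytorch_lightning` also works

class ManualOptExample(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)
        self.automatic_optimization = False  # the same switch FG-AdaHAT flips in __init__

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt = self.optimizers()
        opt.zero_grad()
        loss = nn.functional.mse_loss(self.layer(x), y)
        self.manual_backward(loss)
        # gradients could be modified here, as FG-AdaHAT does with its adjustment rates
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)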
230 def sanity_check(self) -> None: 231 r"""Sanity check.""" 232 233 # check importance type 234 if self.importance_type not in [ 235 "input_weight_abs_sum", 236 "output_weight_abs_sum", 237 "input_weight_gradient_abs_sum", 238 "output_weight_gradient_abs_sum", 239 "activation_abs", 240 "input_weight_abs_sum_x_activation_abs", 241 "output_weight_abs_sum_x_activation_abs", 242 "gradient_x_activation_abs", 243 "input_weight_gradient_square_sum", 244 "output_weight_gradient_square_sum", 245 "input_weight_gradient_square_sum_x_activation_abs", 246 "output_weight_gradient_square_sum_x_activation_abs", 247 "conductance_abs", 248 "internal_influence_abs", 249 "gradcam_abs", 250 "deeplift_abs", 251 "deepliftshap_abs", 252 "gradientshap_abs", 253 "integrated_gradients_abs", 254 "feature_ablation_abs", 255 "lrp_abs", 256 "cbp_adaptation", 257 "cbp_adaptive_contribution", 258 ]: 259 raise ValueError( 260 f"importance_type must be one of the predefined types, but got {self.importance_type}" 261 ) 262 263 # check importance summing strategy 264 if self.importance_summing_strategy not in [ 265 "add_latest", 266 "add_all", 267 "add_average", 268 "linear_decrease", 269 "quadratic_decrease", 270 "cubic_decrease", 271 "exponential_decrease", 272 "log_decrease", 273 "factorial_decrease", 274 ]: 275 raise ValueError( 276 f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}" 277 ) 278 279 # check importance scheduler type 280 if self.importance_scheduler_type not in [ 281 "linear_sparsity_reg", 282 "sparsity_reg", 283 "summative_mask_sparsity_reg", 284 ]: 285 raise ValueError( 286 f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}" 287 ) 288 289 # check neuron to weight importance aggregation mode 290 if self.neuron_to_weight_importance_aggregation_mode not in [ 291 "min", 292 "max", 293 "mean", 294 ]: 295 raise ValueError( 296 f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}" 297 ) 298 299 # check base values 300 if self.base_importance < 0: 301 raise ValueError( 302 f"base_importance must be >= 0, but got {self.base_importance}" 303 ) 304 if self.base_mask_sparsity_reg <= 0: 305 raise ValueError( 306 f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}" 307 ) 308 if self.base_linear <= 0: 309 raise ValueError(f"base_linear must be > 0, but got {self.base_linear}")
Sanity check.
311 def on_train_start(self) -> None: 312 r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.""" 313 super().on_train_start() 314 315 self.importances[self.task_id] = ( 316 {} 317 ) # initialize the importance for the current task 318 319 # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time. 320 for layer_name in self.backbone.weighted_layer_names: 321 layer = self.backbone.get_layer_by_name( 322 layer_name 323 ) # get the layer by its name 324 num_units = layer.weight.shape[0] 325 326 # initialize the accumulated importance at the beginning of each task 327 self.importances[self.task_id][layer_name] = torch.zeros(num_units).to( 328 self.device 329 ) 330 331 # reset the number of steps counter for the current task 332 self.num_steps_t = 0 333 334 # initialize the summative neuron-wise importance at the beginning of the first task 335 if self.task_id == 1: 336 self.summative_importance_for_previous_tasks[layer_name] = torch.zeros( 337 num_units 338 ).to( 339 self.device 340 ) # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$
Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.
342 def clip_grad_by_adjustment( 343 self, 344 network_sparsity: dict[str, Tensor], 345 ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]: 346 r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper. 347 348 Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code. 349 350 Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9). 351 352 **Args:** 353 - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler. 354 355 **Returns:** 356 - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 357 - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. 358 - **capacity** (`Tensor`): the calculated network capacity. 359 """ 360 361 # initialize network capacity metric 362 capacity = HATNetworkCapacityMetric().to(self.device) 363 adjustment_rate_weight = {} 364 adjustment_rate_bias = {} 365 366 # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper 367 for layer_name in self.backbone.weighted_layer_names: 368 369 layer = self.backbone.get_layer_by_name( 370 layer_name 371 ) # get the layer by its name 372 373 # placeholder for the adjustment rate to avoid the error of using it before assignment 374 adjustment_rate_weight_layer = 1 375 adjustment_rate_bias_layer = 1 376 377 # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, added the base value, and filtered by the mask 378 weight_importance, bias_importance = ( 379 self.backbone.get_layer_measure_parameter_wise( 380 neuron_wise_measure=self.summative_importance_for_previous_tasks, 381 layer_name=layer_name, 382 aggregation_mode=self.neuron_to_weight_importance_aggregation_mode, 383 ) 384 ) 385 386 weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise( 387 neuron_wise_measure=self.cumulative_mask_for_previous_tasks, 388 layer_name=layer_name, 389 aggregation_mode="min", 390 ) 391 392 # filter the weight importance by the cumulative mask 393 if self.filter_by_cumulative_mask: 394 weight_importance = weight_importance * weight_mask 395 bias_importance = bias_importance * bias_mask 396 397 network_sparsity_layer = network_sparsity[layer_name] 398 399 # calculate importance scheduler (the factor of importance). See Eq. 
(3) in the paper 400 factor = network_sparsity_layer + self.base_mask_sparsity_reg 401 if self.importance_scheduler_type == "linear_sparsity_reg": 402 factor = factor * (self.task_id + self.base_linear) 403 elif self.importance_scheduler_type == "sparsity_reg": 404 pass 405 elif self.importance_scheduler_type == "summative_mask_sparsity_reg": 406 factor = factor * ( 407 self.summative_mask_for_previous_tasks + self.base_linear 408 ) 409 410 # calculate the adjustment rate 411 adjustment_rate_weight_layer = torch.div( 412 self.adjustment_intensity, 413 (factor * weight_importance + self.adjustment_intensity), 414 ) 415 416 adjustment_rate_bias_layer = torch.div( 417 self.adjustment_intensity, 418 (factor * bias_importance + self.adjustment_intensity), 419 ) 420 421 # apply the adjustment rate to the gradients 422 layer.weight.grad.data *= adjustment_rate_weight_layer 423 if layer.bias is not None: 424 layer.bias.grad.data *= adjustment_rate_bias_layer 425 426 # store the adjustment rate for logging 427 adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer 428 if layer.bias is not None: 429 adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer 430 431 # update network capacity metric 432 capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer) 433 434 return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
Clip the gradients by the adjustment rate. See Eq. (1) in the paper.
Note that the task embeddings cover every weighted layer in the backbone network, so no parameters are left outside this adjustment mechanism: it applies not only to parameters between layers that have task embeddings, but also to those before the first such layer, which are handled separately in the code.
Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.
Args:
- network_sparsity (dict[str, Tensor]): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.

Returns:
- adjustment_rate_weight (dict[str, Tensor]): the adjustment rate for weights. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
- adjustment_rate_bias (dict[str, Tensor]): the adjustment rate for biases. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
- capacity (Tensor): the calculated network capacity.
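To make the adjustment rate computed by clip_grad_by_adjustment concrete, a small numeric sketch of the same formula, adjustment_intensity / (factor * importance + adjustment_intensity), with illustrative values (not taken from the paper):

import torch

adjustment_intensity = 0.1  # the alpha hyperparameter (illustrative value)
factor = 0.5                # importance scheduler value for one layer (illustrative)
weight_importance = torch.tensor([0.0, 0.5, 1.0])  # aggregated weight-wise importance

# Important weights get a rate close to 0 (gradients almost frozen);
# unimportant weights get a rate close to 1 (gradients pass through).
adjustment_rate = adjustment_intensity / (factor * weight_importance + adjustment_intensity)
print(adjustment_rate)  # tensor([1.0000, 0.2857, 0.1667])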
436 def on_train_batch_end( 437 self, outputs: dict[str, Any], batch: Any, batch_idx: int 438 ) -> None: 439 r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step. 440 441 **Args:** 442 - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`). 443 - **batch** (`Any`): training data batch. 444 - **batch_idx** (`int`): index of the current batch (for mask figure file name). 445 """ 446 447 # get potential useful information from training batch 448 activations = outputs["activations"] 449 input = outputs["input"] 450 target = outputs["target"] 451 mask = outputs["mask"] 452 num_batches = self.trainer.num_training_batches 453 454 for layer_name in self.backbone.weighted_layer_names: 455 # layer-wise operation 456 457 activation = activations[layer_name] 458 459 # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. 460 if self.importance_type == "input_weight_abs_sum": 461 importance_step = self.get_importance_step_layer_weight_abs_sum( 462 layer_name=layer_name, 463 if_output_weight=False, 464 reciprocal=False, 465 ) 466 elif self.importance_type == "output_weight_abs_sum": 467 importance_step = self.get_importance_step_layer_weight_abs_sum( 468 layer_name=layer_name, 469 if_output_weight=True, 470 reciprocal=False, 471 ) 472 elif self.importance_type == "input_weight_gradient_abs_sum": 473 importance_step = ( 474 self.get_importance_step_layer_weight_gradient_abs_sum( 475 layer_name=layer_name, if_output_weight=False 476 ) 477 ) 478 elif self.importance_type == "output_weight_gradient_abs_sum": 479 importance_step = ( 480 self.get_importance_step_layer_weight_gradient_abs_sum( 481 layer_name=layer_name, if_output_weight=True 482 ) 483 ) 484 elif self.importance_type == "activation_abs": 485 importance_step = self.get_importance_step_layer_activation_abs( 486 activation=activation 487 ) 488 elif self.importance_type == "input_weight_abs_sum_x_activation_abs": 489 importance_step = ( 490 self.get_importance_step_layer_weight_abs_sum_x_activation_abs( 491 layer_name=layer_name, 492 activation=activation, 493 if_output_weight=False, 494 ) 495 ) 496 elif self.importance_type == "output_weight_abs_sum_x_activation_abs": 497 importance_step = ( 498 self.get_importance_step_layer_weight_abs_sum_x_activation_abs( 499 layer_name=layer_name, 500 activation=activation, 501 if_output_weight=True, 502 ) 503 ) 504 elif self.importance_type == "gradient_x_activation_abs": 505 importance_step = ( 506 self.get_importance_step_layer_gradient_x_activation_abs( 507 layer_name=layer_name, 508 input=input, 509 target=target, 510 batch_idx=batch_idx, 511 num_batches=num_batches, 512 ) 513 ) 514 elif self.importance_type == "input_weight_gradient_square_sum": 515 importance_step = ( 516 self.get_importance_step_layer_weight_gradient_square_sum( 517 layer_name=layer_name, 518 activation=activation, 519 if_output_weight=False, 520 ) 521 ) 522 elif self.importance_type == "output_weight_gradient_square_sum": 523 importance_step = ( 524 self.get_importance_step_layer_weight_gradient_square_sum( 525 layer_name=layer_name, 526 activation=activation, 527 if_output_weight=True, 528 ) 529 ) 530 elif ( 531 self.importance_type 532 == "input_weight_gradient_square_sum_x_activation_abs" 533 ): 534 importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 535 layer_name=layer_name, 536 
activation=activation, 537 if_output_weight=False, 538 ) 539 elif ( 540 self.importance_type 541 == "output_weight_gradient_square_sum_x_activation_abs" 542 ): 543 importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 544 layer_name=layer_name, 545 activation=activation, 546 if_output_weight=True, 547 ) 548 elif self.importance_type == "conductance_abs": 549 importance_step = self.get_importance_step_layer_conductance_abs( 550 layer_name=layer_name, 551 input=input, 552 baselines=None, 553 target=target, 554 batch_idx=batch_idx, 555 num_batches=num_batches, 556 ) 557 elif self.importance_type == "internal_influence_abs": 558 importance_step = self.get_importance_step_layer_internal_influence_abs( 559 layer_name=layer_name, 560 input=input, 561 baselines=None, 562 target=target, 563 batch_idx=batch_idx, 564 num_batches=num_batches, 565 ) 566 elif self.importance_type == "gradcam_abs": 567 importance_step = self.get_importance_step_layer_gradcam_abs( 568 layer_name=layer_name, 569 input=input, 570 target=target, 571 batch_idx=batch_idx, 572 num_batches=num_batches, 573 ) 574 elif self.importance_type == "deeplift_abs": 575 importance_step = self.get_importance_step_layer_deeplift_abs( 576 layer_name=layer_name, 577 input=input, 578 baselines=None, 579 target=target, 580 batch_idx=batch_idx, 581 num_batches=num_batches, 582 ) 583 elif self.importance_type == "deepliftshap_abs": 584 importance_step = self.get_importance_step_layer_deepliftshap_abs( 585 layer_name=layer_name, 586 input=input, 587 baselines=None, 588 target=target, 589 batch_idx=batch_idx, 590 num_batches=num_batches, 591 ) 592 elif self.importance_type == "gradientshap_abs": 593 importance_step = self.get_importance_step_layer_gradientshap_abs( 594 layer_name=layer_name, 595 input=input, 596 baselines=None, 597 target=target, 598 batch_idx=batch_idx, 599 num_batches=num_batches, 600 ) 601 elif self.importance_type == "integrated_gradients_abs": 602 importance_step = ( 603 self.get_importance_step_layer_integrated_gradients_abs( 604 layer_name=layer_name, 605 input=input, 606 baselines=None, 607 target=target, 608 batch_idx=batch_idx, 609 num_batches=num_batches, 610 ) 611 ) 612 elif self.importance_type == "feature_ablation_abs": 613 importance_step = self.get_importance_step_layer_feature_ablation_abs( 614 layer_name=layer_name, 615 input=input, 616 layer_baselines=None, 617 target=target, 618 batch_idx=batch_idx, 619 num_batches=num_batches, 620 ) 621 elif self.importance_type == "lrp_abs": 622 importance_step = self.get_importance_step_layer_lrp_abs( 623 layer_name=layer_name, 624 input=input, 625 target=target, 626 batch_idx=batch_idx, 627 num_batches=num_batches, 628 ) 629 elif self.importance_type == "cbp_adaptation": 630 importance_step = self.get_importance_step_layer_weight_abs_sum( 631 layer_name=layer_name, 632 if_output_weight=False, 633 reciprocal=True, 634 ) 635 elif self.importance_type == "cbp_adaptive_contribution": 636 importance_step = ( 637 self.get_importance_step_layer_cbp_adaptive_contribution( 638 layer_name=layer_name, 639 activation=activation, 640 ) 641 ) 642 643 importance_step = min_max_normalize( 644 importance_step 645 ) # min-max scaling the utility to $[0, 1]$. See Eq. (5) in the paper 646 647 # multiply the importance by the training mask. See Eq. 
(6) in the paper 648 if self.step_multiply_training_mask: 649 importance_step = importance_step * mask[layer_name] 650 651 # update accumulated importance 652 self.importances[self.task_id][layer_name] = ( 653 self.importances[self.task_id][layer_name] + importance_step 654 ) 655 656 # update number of steps counter 657 self.num_steps_t += 1
Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.
Args:
- outputs (dict[str, Any]): outputs of the training step (returns of training_step() in CLAlgorithm).
- batch (Any): training data batch.
- batch_idx (int): index of the current batch (for mask figure file name).
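A minimal sketch of the per-step scaling and masking that on_train_batch_end applies to each layer's importance; min_max_scale below is a stand-in for clarena's min_max_normalize, and the numbers are illustrative:

import torch

def min_max_scale(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale a 1-D importance tensor to [0, 1]; a stand-in for clarena's min_max_normalize."""
    return (x - x.min()) / (x.max() - x.min() + eps)

raw_importance = torch.tensor([0.2, 3.0, 1.4, 0.0])  # illustrative step-wise importance of 4 units
training_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])   # illustrative HAT mask of the current task

importance_step = min_max_scale(raw_importance) * training_mask  # scale to [0, 1], then mask
accumulated_importance = torch.zeros(4)
accumulated_importance += importance_step  # accumulated over steps, later averaged in on_train_end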
def on_train_end(self) -> None:
    r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task."""
    super().on_train_end()  # store the mask and update cumulative and summative masks

    for layer_name in self.backbone.weighted_layer_names:

        # average the neuron-wise step importance. See Eq. (4) in the paper
        self.importances[self.task_id][layer_name] = (
            self.importances[self.task_id][layer_name] / self.num_steps_t
        )

        # add the base importance. See Eq. (6) in the paper
        self.importances[self.task_id][layer_name] = (
            self.importances[self.task_id][layer_name] + self.base_importance
        )

        # filter unmasked importance
        if self.filter_unmasked_importance:
            self.importances[self.task_id][layer_name] = (
                self.importances[self.task_id][layer_name]
                * self.backbone.masks[f"{self.task_id}"][layer_name]
            )

        # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
        if self.importance_summing_strategy == "add_latest":
            self.summative_importance_for_previous_tasks[layer_name] += self.importances[
                self.task_id
            ][layer_name]
        elif self.importance_summing_strategy == "add_all":
            for t in range(1, self.task_id + 1):
                self.summative_importance_for_previous_tasks[layer_name] += self.importances[
                    t
                ][layer_name]
        elif self.importance_summing_strategy == "add_average":
            for t in range(1, self.task_id + 1):
                self.summative_importance_for_previous_tasks[layer_name] += (
                    self.importances[t][layer_name] / self.task_id
                )
        else:
            # weighted strategies: start accumulating from zero and weight each task's
            # importance by a strategy-dependent factor w_t
            self.summative_importance_for_previous_tasks[layer_name] = torch.zeros_like(
                self.summative_importance_for_previous_tasks[layer_name]
            ).to(self.device)

            for t in range(1, self.task_id + 1):
                if self.importance_summing_strategy == "linear_decrease":
                    w_t = self.importance_summing_strategy_linear_step * (self.task_id - t) + 1
                elif self.importance_summing_strategy == "quadratic_decrease":
                    w_t = (self.task_id - t + 1) ** 2
                elif self.importance_summing_strategy == "cubic_decrease":
                    w_t = (self.task_id - t + 1) ** 3
                elif self.importance_summing_strategy == "exponential_decrease":
                    w_t = self.importance_summing_strategy_exponential_rate ** (
                        self.task_id - t + 1
                    )
                elif self.importance_summing_strategy == "log_decrease":
                    # shifted by one so the most recent task (t == task_id) does not hit log(0)
                    w_t = (
                        math.log(
                            self.task_id - t + 1,
                            self.importance_summing_strategy_log_base,
                        )
                        + 1
                    )
                elif self.importance_summing_strategy == "factorial_decrease":
                    w_t = math.factorial(self.task_id - t + 1)
                else:
                    raise ValueError(
                        f"unknown importance_summing_strategy: {self.importance_summing_strategy}"
                    )

                # accumulate the weighted importance of task t
                self.summative_importance_for_previous_tasks[layer_name] += (
                    self.importances[t][layer_name] * w_t
                )
Additionally calculate neuron-wise importance for previous tasks at the end of training each task.
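To illustrate the weighting behind the decrease-style summing strategies, a small helper that mirrors the formulas in on_train_end above; the default rates and the one-step shift in the log variant are assumptions, not values from the paper:

import math

def summing_weight(strategy: str, task_id: int, t: int,
                   linear_step: float = 1.0, exp_rate: float = 0.5, log_base: float = 10.0) -> float:
    """Weight w_t given to task t's importance when summing up to task task_id (illustrative)."""
    age = task_id - t + 1  # 1 for the most recent task, larger for older tasks
    if strategy == "linear_decrease":
        return linear_step * (task_id - t) + 1
    if strategy == "quadratic_decrease":
        return age ** 2
    if strategy == "cubic_decrease":
        return age ** 3
    if strategy == "exponential_decrease":
        return exp_rate ** age
    if strategy == "log_decrease":
        return math.log(age, log_base) + 1  # shifted by one to avoid log(0)
    if strategy == "factorial_decrease":
        return float(math.factorial(age))
    raise ValueError(f"unknown strategy: {strategy}")

# With 3 tasks seen, quadratic_decrease weights tasks 1, 2, 3 as 9, 4, 1 respectively.
print([summing_weight("quadratic_decrease", task_id=3, t=t) for t in (1, 2, 3)])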
737 def get_importance_step_layer_weight_abs_sum( 738 self: str, 739 layer_name: str, 740 if_output_weight: bool, 741 reciprocal: bool, 742 ) -> Tensor: 743 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights. 744 745 **Args:** 746 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 747 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 748 - **reciprocal** (`bool`): whether to take reciprocal. 749 750 **Returns:** 751 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 752 """ 753 layer = self.backbone.get_layer_by_name(layer_name) 754 755 if not if_output_weight: 756 weight_abs = torch.abs(layer.weight.data) 757 weight_abs_sum = torch.sum( 758 weight_abs, 759 dim=[ 760 i for i in range(weight_abs.dim()) if i != 0 761 ], # sum over the input dimension 762 ) 763 else: 764 weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 765 weight_abs_sum = torch.sum( 766 weight_abs, 767 dim=[ 768 i for i in range(weight_abs.dim()) if i != 1 769 ], # sum over the output dimension 770 ) 771 772 if reciprocal: 773 weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum) 774 importance_step_layer = weight_abs_sum_reciprocal 775 else: 776 importance_step_layer = weight_abs_sum 777 importance_step_layer = importance_step_layer.detach() 778 779 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- if_output_weight (bool): whether to use the output weights or input weights.
- reciprocal (bool): whether to take reciprocal.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
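A minimal sketch of the per-unit sum of absolute input weights that this measure computes, for a hypothetical Linear and Conv2d layer:

import torch
from torch import nn

fc = nn.Linear(8, 4)       # hypothetical layer: 4 units, 8 input weights each
conv = nn.Conv2d(3, 6, 3)  # hypothetical layer: 6 units (output channels)

# Sum |w| over every dimension except dim 0 (the unit dimension), giving one value per unit.
fc_importance = fc.weight.detach().abs().sum(dim=[d for d in range(fc.weight.dim()) if d != 0])
conv_importance = conv.weight.detach().abs().sum(dim=[d for d in range(conv.weight.dim()) if d != 0])
print(fc_importance.shape, conv_importance.shape)  # torch.Size([4]) torch.Size([6])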
781 def get_importance_step_layer_weight_gradient_abs_sum( 782 self: str, 783 layer_name: str, 784 if_output_weight: bool, 785 ) -> Tensor: 786 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights. 787 788 **Args:** 789 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 790 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 791 792 **Returns:** 793 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 794 """ 795 layer = self.backbone.get_layer_by_name(layer_name) 796 797 if not if_output_weight: 798 gradient_abs = torch.abs(layer.weight.grad.data) 799 gradient_abs_sum = torch.sum( 800 gradient_abs, 801 dim=[ 802 i for i in range(gradient_abs.dim()) if i != 0 803 ], # sum over the input dimension 804 ) 805 else: 806 gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data) 807 gradient_abs_sum = torch.sum( 808 gradient_abs, 809 dim=[ 810 i for i in range(gradient_abs.dim()) if i != 1 811 ], # sum over the output dimension 812 ) 813 814 importance_step_layer = gradient_abs_sum 815 importance_step_layer = importance_step_layer.detach() 816 817 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- if_output_weight (bool): whether to use the output weights or input weights.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
819 def get_importance_step_layer_activation_abs( 820 self: str, 821 activation: Tensor, 822 ) -> Tensor: 823 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum. 824 825 **Args:** 826 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 827 828 **Returns:** 829 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 830 """ 831 activation_abs_batch_mean = torch.mean( 832 torch.abs(activation), 833 dim=[ 834 i for i in range(activation.dim()) if i != 1 835 ], # average the features over batch samples 836 ) 837 importance_step_layer = activation_abs_batch_mean 838 importance_step_layer = importance_step_layer.detach() 839 840 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of Layer Activation in Captum.
Args:
- activation (Tensor): the activation tensor of the layer. It has the same size of (number of units, ).

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
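A minimal sketch of the batch-averaged absolute activation used here; the batch and unit sizes are illustrative:

import torch

activation = torch.randn(32, 10)  # illustrative activations: (batch size, number of units)

# Mean |activation| per unit, averaging over every dimension except the unit dimension (dim 1).
importance = activation.abs().mean(dim=[d for d in range(activation.dim()) if d != 1]).detach()
print(importance.shape)  # torch.Size([10])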
842 def get_importance_step_layer_weight_abs_sum_x_activation_abs( 843 self: str, 844 layer_name: str, 845 activation: Tensor, 846 if_output_weight: bool, 847 ) -> Tensor: 848 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7). 849 850 **Args:** 851 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 852 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 853 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 854 855 **Returns:** 856 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 857 """ 858 layer = self.backbone.get_layer_by_name(layer_name) 859 860 if not if_output_weight: 861 weight_abs = torch.abs(layer.weight.data) 862 weight_abs_sum = torch.sum( 863 weight_abs, 864 dim=[ 865 i for i in range(weight_abs.dim()) if i != 0 866 ], # sum over the input dimension 867 ) 868 else: 869 weight_abs = torch.abs(self.next_layer(layer_name).weight.data) 870 weight_abs_sum = torch.sum( 871 weight_abs, 872 dim=[ 873 i for i in range(weight_abs.dim()) if i != 1 874 ], # sum over the output dimension 875 ) 876 877 activation_abs_batch_mean = torch.mean( 878 torch.abs(activation), 879 dim=[ 880 i for i in range(activation.dim()) if i != 1 881 ], # average the features over batch samples 882 ) 883 884 importance_step_layer = weight_abs_sum * activation_abs_batch_mean 885 importance_step_layer = importance_step_layer.detach() 886 887 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in CBP.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- activation (Tensor): the activation tensor of the layer. It has the same size of (number of units, ).
- if_output_weight (bool): whether to use the output weights or input weights.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
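A minimal sketch of the contribution-utility-style product behind this measure, combining the per-unit weight sum with the batch-averaged absolute activation; the layer and activations are hypothetical:

import torch
from torch import nn

layer = nn.Linear(8, 4)          # hypothetical layer with 4 units
activation = torch.randn(32, 4)  # its activations for a batch of 32 samples

weight_abs_sum = layer.weight.detach().abs().sum(dim=1)  # per-unit sum of |input weights|
activation_abs_mean = activation.abs().mean(dim=0)       # per-unit mean |activation|
importance = weight_abs_sum * activation_abs_mean        # contribution-utility-style product
print(importance.shape)  # torch.Size([4])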
889 def get_importance_step_layer_gradient_x_activation_abs( 890 self: str, 891 layer_name: str, 892 input: Tensor | tuple[Tensor, ...], 893 target: Tensor | None, 894 batch_idx: int, 895 num_batches: int, 896 ) -> Tensor: 897 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum. 898 899 **Args:** 900 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 901 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 902 - **target** (`Tensor` | `None`): the target batch of the training step. 903 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 904 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 905 906 **Returns:** 907 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 908 """ 909 layer = self.backbone.get_layer_by_name(layer_name) 910 911 input = input.requires_grad_() 912 913 # initialize the Layer Gradient X Activation object 914 layer_gradient_x_activation = LayerGradientXActivation( 915 forward_func=self.forward, layer=layer 916 ) 917 918 self.set_forward_func_return_logits_only(True) 919 # calculate layer attribution of the step 920 attribution = layer_gradient_x_activation.attribute( 921 inputs=input, 922 target=target, 923 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 924 ) 925 self.set_forward_func_return_logits_only(False) 926 927 attribution_abs_batch_mean = torch.mean( 928 torch.abs(attribution), 929 dim=[ 930 i for i in range(attribution.dim()) if i != 1 931 ], # average the features over batch samples 932 ) 933 934 importance_step_layer = attribution_abs_batch_mean 935 importance_step_layer = importance_step_layer.detach() 936 937 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using Layer Gradient X Activation in Captum.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
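A minimal, standalone sketch of Layer Gradient X Activation in Captum on a toy classifier; unlike the module's forward, this toy forward takes no additional arguments, and the model and sizes are assumptions:

import torch
from torch import nn
from captum.attr import LayerGradientXActivation

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))  # toy classifier
layer = model[0]  # the layer whose units we attribute to

attribution_method = LayerGradientXActivation(forward_func=model, layer=layer)

inputs = torch.randn(4, 8)
targets = torch.tensor([0, 2, 1, 0])
attribution = attribution_method.attribute(inputs=inputs, target=targets)  # shape (4, 16)

# Neuron-wise importance: mean absolute attribution over the batch dimension.
importance = attribution.abs().mean(dim=0).detach()
print(importance.shape)  # torch.Size([16])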
939 def get_importance_step_layer_weight_gradient_square_sum( 940 self: str, 941 layer_name: str, 942 activation: Tensor, 943 if_output_weight: bool, 944 ) -> Tensor: 945 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114). 946 947 **Args:** 948 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 949 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 950 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 951 952 **Returns:** 953 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 954 """ 955 layer = self.backbone.get_layer_by_name(layer_name) 956 957 if not if_output_weight: 958 gradient_square = layer.weight.grad.data**2 959 gradient_square_sum = torch.sum( 960 gradient_square, 961 dim=[ 962 i for i in range(gradient_square.dim()) if i != 0 963 ], # sum over the input dimension 964 ) 965 else: 966 gradient_square = self.next_layer(layer_name).weight.grad.data**2 967 gradient_square_sum = torch.sum( 968 gradient_square, 969 dim=[ 970 i for i in range(gradient_square.dim()) if i != 1 971 ], # sum over the output dimension 972 ) 973 974 importance_step_layer = gradient_square_sum 975 importance_step_layer = importance_step_layer.detach() 976 977 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The squared weight gradient corresponds to the Fisher information used in EWC.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- activation (Tensor): the activation tensor of the layer. It has the same size of (number of units, ).
- if_output_weight (bool): whether to use the output weights or input weights.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
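A minimal sketch of the per-unit sum of squared input-weight gradients after one backward pass; the layer and batch are hypothetical:

import torch
from torch import nn

layer = nn.Linear(8, 4)  # hypothetical layer with 4 units
inputs = torch.randn(16, 8)
targets = torch.randint(0, 4, (16,))

loss = nn.functional.cross_entropy(layer(inputs), targets)
loss.backward()  # populates layer.weight.grad

# Per-unit sum of squared input-weight gradients (a Fisher-information-style signal).
grad_sq = layer.weight.grad.detach() ** 2
importance = grad_sq.sum(dim=[d for d in range(grad_sq.dim()) if d != 0])
print(importance.shape)  # torch.Size([4])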
979 def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( 980 self: str, 981 layer_name: str, 982 activation: Tensor, 983 if_output_weight: bool, 984 ) -> Tensor: 985 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114). 986 987 **Args:** 988 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 989 - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ). 990 - **if_output_weight** (`bool`): whether to use the output weights or input weights. 991 992 **Returns:** 993 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 994 """ 995 layer = self.backbone.get_layer_by_name(layer_name) 996 997 if not if_output_weight: 998 gradient_square = layer.weight.grad.data**2 999 gradient_square_sum = torch.sum( 1000 gradient_square, 1001 dim=[ 1002 i for i in range(gradient_square.dim()) if i != 0 1003 ], # sum over the input dimension 1004 ) 1005 else: 1006 gradient_square = self.next_layer(layer_name).weight.grad.data**2 1007 gradient_square_sum = torch.sum( 1008 gradient_square, 1009 dim=[ 1010 i for i in range(gradient_square.dim()) if i != 1 1011 ], # sum over the output dimension 1012 ) 1013 1014 activation_abs_batch_mean = torch.mean( 1015 torch.abs(activation), 1016 dim=[ 1017 i for i in range(activation.dim()) if i != 1 1018 ], # average the features over batch samples 1019 ) 1020 1021 importance_step_layer = gradient_square_sum * activation_abs_batch_mean 1022 importance_step_layer = importance_step_layer.detach() 1023 1024 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The squared weight gradient corresponds to the Fisher information used in EWC.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- activation (Tensor): the activation tensor of the layer. It has the same size of (number of units, ).
- if_output_weight (bool): whether to use the output weights or input weights.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
1026 def get_importance_step_layer_conductance_abs( 1027 self: str, 1028 layer_name: str, 1029 input: Tensor | tuple[Tensor, ...], 1030 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1031 target: Tensor | None, 1032 batch_idx: int, 1033 num_batches: int, 1034 ) -> Tensor: 1035 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum. 1036 1037 **Args:** 1038 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1039 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1040 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details. 1041 - **target** (`Tensor` | `None`): the target batch of the training step. 1042 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1043 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.- **mask** (`Tensor`): the mask tensor of the layer. It has the same size as the feature tensor with size (number of units, ). 1044 1045 **Returns:** 1046 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1047 """ 1048 layer = self.backbone.get_layer_by_name(layer_name) 1049 1050 # initialize the Layer Conductance object 1051 layer_conductance = LayerConductance(forward_func=self.forward, layer=layer) 1052 1053 self.set_forward_func_return_logits_only(True) 1054 # calculate layer attribution of the step 1055 attribution = layer_conductance.attribute( 1056 inputs=input, 1057 baselines=baselines, 1058 target=target, 1059 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1060 ) 1061 self.set_forward_func_return_logits_only(False) 1062 1063 attribution_abs_batch_mean = torch.mean( 1064 torch.abs(attribution), 1065 dim=[ 1066 i for i in range(attribution.dim()) if i != 1 1067 ], # average the features over batch samples 1068 ) 1069 1070 importance_step_layer = attribution_abs_batch_mean 1071 importance_step_layer = importance_step_layer.detach() 1072 1073 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of conductance. We implement this using Layer Conductance in Captum.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed in this method. Please refer to the Captum documentation for more details.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
1075 def get_importance_step_layer_internal_influence_abs( 1076 self: str, 1077 layer_name: str, 1078 input: Tensor | tuple[Tensor, ...], 1079 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1080 target: Tensor | None, 1081 batch_idx: int, 1082 num_batches: int, 1083 ) -> Tensor: 1084 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum. 1085 1086 **Args:** 1087 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1088 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1089 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details. 1090 - **target** (`Tensor` | `None`): the target batch of the training step. 1091 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1092 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1093 1094 **Returns:** 1095 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1096 """ 1097 layer = self.backbone.get_layer_by_name(layer_name) 1098 1099 # initialize the Internal Influence object 1100 internal_influence = InternalInfluence(forward_func=self.forward, layer=layer) 1101 1102 # convert the target to long type to avoid error 1103 target = target.long() if target is not None else None 1104 1105 self.set_forward_func_return_logits_only(True) 1106 # calculate layer attribution of the step 1107 attribution = internal_influence.attribute( 1108 inputs=input, 1109 baselines=baselines, 1110 target=target, 1111 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1112 n_steps=5, # set 10 instead of default 50 to accelerate the computation 1113 ) 1114 self.set_forward_func_return_logits_only(False) 1115 1116 attribution_abs_batch_mean = torch.mean( 1117 torch.abs(attribution), 1118 dim=[ 1119 i for i in range(attribution.dim()) if i != 1 1120 ], # average the features over batch samples 1121 ) 1122 1123 importance_step_layer = attribution_abs_batch_mean 1124 importance_step_layer = importance_step_layer.detach() 1125 1126 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of internal influence. We implement this using Internal Influence in Captum.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed in this method. Please refer to the Captum documentation for more details.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
1128 def get_importance_step_layer_gradcam_abs( 1129 self: str, 1130 layer_name: str, 1131 input: Tensor | tuple[Tensor, ...], 1132 target: Tensor | None, 1133 batch_idx: int, 1134 num_batches: int, 1135 ) -> Tensor: 1136 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum. 1137 1138 **Args:** 1139 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1140 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1141 - **target** (`Tensor` | `None`): the target batch of the training step. 1142 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1143 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1144 1145 **Returns:** 1146 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1147 """ 1148 layer = self.backbone.get_layer_by_name(layer_name) 1149 1150 # initialize the GradCAM object 1151 gradcam = LayerGradCam(forward_func=self.forward, layer=layer) 1152 1153 self.set_forward_func_return_logits_only(True) 1154 # calculate layer attribution of the step 1155 attribution = gradcam.attribute( 1156 inputs=input, 1157 target=target, 1158 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1159 ) 1160 self.set_forward_func_return_logits_only(False) 1161 1162 attribution_abs_batch_mean = torch.mean( 1163 torch.abs(attribution), 1164 dim=[ 1165 i for i in range(attribution.dim()) if i != 1 1166 ], # average the features over batch samples 1167 ) 1168 1169 importance_step_layer = attribution_abs_batch_mean 1170 importance_step_layer = importance_step_layer.detach() 1171 1172 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of Grad-CAM. We implement this using Layer Grad-CAM in Captum.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
1174 def get_importance_step_layer_deeplift_abs( 1175 self: str, 1176 layer_name: str, 1177 input: Tensor | tuple[Tensor, ...], 1178 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1179 target: Tensor | None, 1180 batch_idx: int, 1181 num_batches: int, 1182 ) -> Tensor: 1183 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum. 1184 1185 **Args:** 1186 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1187 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1188 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details. 1189 - **target** (`Tensor` | `None`): the target batch of the training step. 1190 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1191 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1192 1193 **Returns:** 1194 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1195 """ 1196 layer = self.backbone.get_layer_by_name(layer_name) 1197 1198 # initialize the Layer DeepLift object 1199 layer_deeplift = LayerDeepLift(model=self, layer=layer) 1200 1201 # convert the target to long type to avoid error 1202 target = target.long() if target is not None else None 1203 1204 self.set_forward_func_return_logits_only(True) 1205 # calculate layer attribution of the step 1206 attribution = layer_deeplift.attribute( 1207 inputs=input, 1208 baselines=baselines, 1209 target=target, 1210 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1211 ) 1212 self.set_forward_func_return_logits_only(False) 1213 1214 attribution_abs_batch_mean = torch.mean( 1215 torch.abs(attribution), 1216 dim=[ 1217 i for i in range(attribution.dim()) if i != 1 1218 ], # average the features over batch samples 1219 ) 1220 1221 importance_step_layer = attribution_abs_batch_mean 1222 importance_step_layer = importance_step_layer.detach() 1223 1224 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of DeepLift. We implement this using Layer DeepLift in Captum.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): baselines define reference samples that are compared with the inputs. Please refer to the Captum documentation for more details.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
1226 def get_importance_step_layer_deepliftshap_abs( 1227 self: str, 1228 layer_name: str, 1229 input: Tensor | tuple[Tensor, ...], 1230 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1231 target: Tensor | None, 1232 batch_idx: int, 1233 num_batches: int, 1234 ) -> Tensor: 1235 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum. 1236 1237 **Args:** 1238 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1239 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1240 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details. 1241 - **target** (`Tensor` | `None`): the target batch of the training step. 1242 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1243 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1244 1245 **Returns:** 1246 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1247 """ 1248 layer = self.backbone.get_layer_by_name(layer_name) 1249 1250 # initialize the Layer DeepLiftShap object 1251 layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer) 1252 1253 # convert the target to long type to avoid error 1254 target = target.long() if target is not None else None 1255 1256 self.set_forward_func_return_logits_only(True) 1257 # calculate layer attribution of the step 1258 attribution = layer_deepliftshap.attribute( 1259 inputs=input, 1260 baselines=baselines, 1261 target=target, 1262 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1263 ) 1264 self.set_forward_func_return_logits_only(False) 1265 1266 attribution_abs_batch_mean = torch.mean( 1267 torch.abs(attribution), 1268 dim=[ 1269 i for i in range(attribution.dim()) if i != 1 1270 ], # average the features over batch samples 1271 ) 1272 1273 importance_step_layer = attribution_abs_batch_mean 1274 importance_step_layer = importance_step_layer.detach() 1275 1276 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of DeepLift SHAP. We implement this using Layer DeepLiftShap in Captum.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): baselines define reference samples that are compared with the inputs. Please refer to the Captum documentation for more details.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
1278 def get_importance_step_layer_gradientshap_abs( 1279 self: str, 1280 layer_name: str, 1281 input: Tensor | tuple[Tensor, ...], 1282 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1283 target: Tensor | None, 1284 batch_idx: int, 1285 num_batches: int, 1286 ) -> Tensor: 1287 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum. 1288 1289 **Args:** 1290 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1291 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1292 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero. 1293 - **target** (`Tensor` | `None`): the target batch of the training step. 1294 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1295 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1296 1297 **Returns:** 1298 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1299 """ 1300 layer = self.backbone.get_layer_by_name(layer_name) 1301 1302 if baselines is None: 1303 baselines = torch.zeros_like( 1304 input 1305 ) # baselines are mandatory for GradientShap API. We explicitly set them to zero 1306 1307 # initialize the Layer GradientShap object 1308 layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer) 1309 1310 # convert the target to long type to avoid error 1311 target = target.long() if target is not None else None 1312 1313 self.set_forward_func_return_logits_only(True) 1314 # calculate layer attribution of the step 1315 attribution = layer_gradientshap.attribute( 1316 inputs=input, 1317 baselines=baselines, 1318 target=target, 1319 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1320 ) 1321 self.set_forward_func_return_logits_only(False) 1322 1323 attribution_abs_batch_mean = torch.mean( 1324 torch.abs(attribution), 1325 dim=[ 1326 i for i in range(attribution.dim()) if i != 1 1327 ], # average the features over batch samples 1328 ) 1329 1330 importance_step_layer = attribution_abs_batch_mean 1331 importance_step_layer = importance_step_layer.detach() 1332 1333 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using Layer GradientShap in Captum.
Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which expectation is computed. Please refer to the Captum documentation for more details. If None, the baselines are set to zero.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
1335 def get_importance_step_layer_integrated_gradients_abs( 1336 self: str, 1337 layer_name: str, 1338 input: Tensor | tuple[Tensor, ...], 1339 baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...], 1340 target: Tensor | None, 1341 batch_idx: int, 1342 num_batches: int, 1343 ) -> Tensor: 1344 r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum. 1345 1346 **Args:** 1347 - **layer_name** (`str`): the name of layer to get neuron-wise importance. 1348 - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step. 1349 - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details. 1350 - **target** (`Tensor` | `None`): the target batch of the training step. 1351 - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training. 1352 - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training. 1353 1354 **Returns:** 1355 - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step. 1356 """ 1357 layer = self.backbone.get_layer_by_name(layer_name) 1358 1359 # initialize the Layer Integrated Gradients object 1360 layer_integrated_gradients = LayerIntegratedGradients( 1361 forward_func=self.forward, layer=layer 1362 ) 1363 1364 self.set_forward_func_return_logits_only(True) 1365 # calculate layer attribution of the step 1366 attribution = layer_integrated_gradients.attribute( 1367 inputs=input, 1368 baselines=baselines, 1369 target=target, 1370 additional_forward_args=("train", batch_idx, num_batches, self.task_id), 1371 ) 1372 self.set_forward_func_return_logits_only(False) 1373 1374 attribution_abs_batch_mean = torch.mean( 1375 torch.abs(attribution), 1376 dim=[ 1377 i for i in range(attribution.dim()) if i != 1 1378 ], # average the features over batch samples 1379 ) 1380 1381 importance_step_layer = attribution_abs_batch_mean 1382 importance_step_layer = importance_step_layer.detach() 1383 1384 return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of integrated gradients. We implement this using Layer Integrated Gradients in Captum.

Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed. Please refer to the Captum documentation for more details.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
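For convolutional layers, the same batch-mean reduction yields one importance value per channel. Below is a minimal sketch, not part of clarena; the toy model and shapes are assumptions for illustration only. It shows that a Layer Integrated Gradients attribution of shape (batch, channels, H, W) reduces to a per-channel vector when averaged over every dimension except dim 1.

# Minimal sketch (outside clarena): per-channel importance from Layer Integrated Gradients.
import torch
from captum.attr import LayerIntegratedGradients

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 5),
)
conv = model[0]

inputs = torch.randn(4, 3, 16, 16)
targets = torch.randint(0, 5, (4,))

lig = LayerIntegratedGradients(forward_func=model, layer=conv)
attribution = lig.attribute(inputs=inputs, baselines=0, target=targets)  # (4, 8, 16, 16)

# average |attribution| over batch and spatial dimensions, keeping the channel dimension
importance = torch.mean(
    torch.abs(attribution), dim=[i for i in range(attribution.dim()) if i != 1]
).detach()
print(importance.shape)  # torch.Size([8])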
    def get_importance_step_layer_feature_ablation_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
        if_captum: bool = False,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
        - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        if not if_captum:
            # 1. Baseline logits (take first element of forward output)
            baseline_out, _, _ = self.forward(
                input, "train", batch_idx, num_batches, self.task_id
            )
            if target is not None:
                baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1)
            else:
                baseline_scores = baseline_out.sum(dim=1)

            # 2. Capture layer's output shape
            activs = {}
            handle = layer.register_forward_hook(
                lambda module, inp, out: activs.setdefault("output", out.detach())
            )
            _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id)
            handle.remove()
            layer_output = activs["output"]  # shape (B, F, ...)

            # 3. Build baseline tensor matching that shape
            if layer_baselines is None:
                baseline_tensor = torch.zeros_like(layer_output)
            elif isinstance(layer_baselines, (int, float)):
                baseline_tensor = torch.full_like(layer_output, layer_baselines)
            elif isinstance(layer_baselines, Tensor):
                if layer_baselines.shape == layer_output.shape:
                    baseline_tensor = layer_baselines
                elif layer_baselines.shape == layer_output.shape[1:]:
                    baseline_tensor = layer_baselines.unsqueeze(0).repeat(
                        layer_output.size(0), *([1] * layer_baselines.ndim)
                    )
                else:
                    raise ValueError(...)
            else:
                raise ValueError(...)

            B, F = layer_output.size(0), layer_output.size(1)

            # 4. Create a "mega-batch" replicating the input F times
            if isinstance(input, tuple):
                mega_inputs = tuple(
                    t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:])
                    for t in input
                )
            else:
                mega_inputs = (
                    input.unsqueeze(0)
                    .repeat(F, *([1] * input.ndim))
                    .view(-1, *input.shape[1:])
                )

            # 5. Equally replicate the baseline tensor
            mega_baseline = (
                baseline_tensor.unsqueeze(0)
                .repeat(F, *([1] * baseline_tensor.ndim))
                .view(-1, *baseline_tensor.shape[1:])
            )

            # 6. Precompute vectorized indices
            device = layer_output.device
            positions = torch.arange(F * B, device=device)  # [0, 1, ..., F*B-1]
            feat_idx = torch.arange(F, device=device).repeat_interleave(
                B
            )  # [0, 0, ..., 1, 1, ..., F-1]

            # 7. One hook to replace each channel slice with its baseline across the mega-batch
            def mega_ablate_hook(module, inp, out):
                out_mod = out.clone()
                # for each sample in the mega-batch, ablate its corresponding channel
                out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx]
                return out_mod

            h = layer.register_forward_hook(mega_ablate_hook)
            out_all, _, _ = self.forward(
                mega_inputs, "train", batch_idx, num_batches, self.task_id
            )
            h.remove()

            # 8. Recover scores, reshape [F*B] -> [F, B], diff & mean
            if target is not None:
                tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1)
                scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1)
            else:
                scores_all = out_all.sum(dim=1)

            scores_all = scores_all.view(F, B)
            diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all)
            importance_step_layer = diffs.mean(dim=1).detach()  # [F]

            return importance_step_layer

        else:
            # initialize the Layer Feature Ablation object
            layer_feature_ablation = LayerFeatureAblation(
                forward_func=self.forward, layer=layer
            )

            # calculate layer attribution of the step
            self.set_forward_func_return_logits_only(True)
            attribution = layer_feature_ablation.attribute(
                inputs=input,
                layer_baselines=layer_baselines,
                # target=target,  # disable target to enable perturbations_per_eval
                additional_forward_args=("train", batch_idx, num_batches, self.task_id),
                perturbations_per_eval=128,  # to accelerate the computation
            )
            self.set_forward_func_return_logits_only(False)

            attribution_abs_batch_mean = torch.mean(
                torch.abs(attribution),
                dim=[
                    i for i in range(attribution.dim()) if i != 1
                ],  # average the features over batch samples
            )

            importance_step_layer = attribution_abs_batch_mean
            importance_step_layer = importance_step_layer.detach()

            return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of feature ablation attribution. We implement this using Layer Feature Ablation in Captum.

Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- layer_baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): reference values which replace each layer input / output value when ablated. Please refer to the Captum documentation for more details.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.
- if_captum (bool): whether to use Captum or not. If True, we use Captum to calculate the feature ablation. If False, we use our implementation. Default is False, because our implementation is much faster.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
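The core idea behind layer feature ablation is simple: replace one unit's activation with a baseline, re-run the forward pass, and record how much the target score changes. The sketch below is not part of clarena and deliberately uses a plain per-unit loop instead of the batched "mega-batch" trick in the implementation above; the toy model, zero baseline, and shapes are assumptions made only for illustration.

# Minimal sketch (outside clarena): per-unit ablation with a forward hook.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 10)
)
layer = model[0]

inputs = torch.randn(8, 16)
targets = torch.randint(0, 10, (8,))

with torch.no_grad():
    # unablated scores of the target class for each sample
    baseline_scores = model(inputs).gather(1, targets.view(-1, 1)).squeeze(1)

    num_units = layer.out_features
    importance = torch.zeros(num_units)
    for unit in range(num_units):
        def ablate(module, inp, out, unit=unit):
            out = out.clone()
            out[:, unit] = 0.0  # replace this unit's activation with the zero baseline
            return out

        handle = layer.register_forward_hook(ablate)
        ablated_scores = model(inputs).gather(1, targets.view(-1, 1)).squeeze(1)
        handle.remove()

        # importance = mean absolute change in the target score
        importance[unit] = torch.abs(baseline_scores - ablated_scores).mean()

print(importance.shape)  # torch.Size([32])

The mega-batch variant above computes all of these ablations in a single forward pass by stacking one copy of the batch per unit, which is why it is much faster than looping.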
    def get_importance_step_layer_lrp_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer LRP object
        layer_lrp = LayerLRP(model=self, layer=layer)

        # set model to evaluation mode to prevent updating the model parameters
        self.eval()

        self.set_forward_func_return_logits_only(True)
        # calculate layer attribution of the step
        attribution = layer_lrp.attribute(
            inputs=input,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of LRP. We implement this using Layer LRP in Captum.

Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
- target (Tensor | None): the target batch of the training step.
- batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
- num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
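As with the other Captum-based measures, the LRP relevance assigned to a layer is reduced to one value per unit by a batch mean of its absolute values. The following sketch is not part of clarena; the toy model and shapes are assumptions for illustration only.

# Minimal sketch (outside clarena): per-unit importance from Layer LRP.
import torch
from captum.attr import LayerLRP

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 10)
)
layer = model[0]

inputs = torch.randn(8, 16)
targets = torch.randint(0, 10, (8,))

model.eval()  # as in the method above, keep the model in evaluation mode
layer_lrp = LayerLRP(model=model, layer=layer)
relevance = layer_lrp.attribute(inputs=inputs, target=targets)  # (8, 32)

importance = torch.mean(
    torch.abs(relevance), dim=[i for i in range(relevance.dim()) if i != 1]
).detach()
print(importance.shape)  # torch.Size([32])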
    def get_importance_step_layer_cbp_adaptive_contribution(
        self,
        layer_name: str,
        activation: Tensor,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).

        **Args:**
        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
        - **activation** (`Tensor`): the activation tensor of the layer for the current input batch.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        input_weight_abs = torch.abs(layer.weight.data)
        input_weight_abs_sum = torch.sum(
            input_weight_abs,
            dim=[
                i for i in range(input_weight_abs.dim()) if i != 0
            ],  # sum over the input dimension
        )
        input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum)

        output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
        output_weight_abs_sum = torch.sum(
            output_weight_abs,
            dim=[
                i for i in range(output_weight_abs.dim()) if i != 1
            ],  # sum over the output dimension
        )

        activation_abs_batch_mean = torch.mean(
            torch.abs(activation),
            dim=[
                i for i in range(activation.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = (
            output_weight_abs_sum
            * activation_abs_batch_mean
            * input_weight_abs_sum_reciprocal
        )
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer
Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in CBP.

Args:
- layer_name (str): the name of layer to get neuron-wise importance.
- activation (Tensor): the activation tensor of the layer for the current input batch.

Returns:
- importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
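To make the arithmetic concrete, here is a minimal pure-PyTorch sketch of the adaptive contribution utility described above for two fully connected layers. It is not part of clarena; the layer sizes and names are assumptions for illustration. Per unit, it multiplies the sum of absolute outgoing weights by the mean absolute activation and divides by the sum of absolute incoming weights, matching the product with the reciprocal in the implementation above.

# Minimal sketch (outside clarena): CBP-style adaptive contribution utility.
import torch

layer = torch.nn.Linear(16, 32)       # the layer whose units are scored
next_layer = torch.nn.Linear(32, 10)  # its successor, providing the outgoing weights

activation = torch.relu(layer(torch.randn(8, 16)))  # (batch, 32)

input_weight_abs_sum = layer.weight.data.abs().sum(dim=1)        # (32,) incoming weights per unit
output_weight_abs_sum = next_layer.weight.data.abs().sum(dim=0)  # (32,) outgoing weights per unit
activation_abs_mean = activation.abs().mean(dim=0)               # (32,) batch-mean |activation|

importance = (
    output_weight_abs_sum * activation_abs_mean / input_weight_abs_sum
).detach()
print(importance.shape)  # torch.Size([32])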