r"""
The submodule in `cl_algorithms` for FG-AdaHAT algorithm.
"""
   4
   5__all__ = ["FGAdaHAT"]
   6
   7import logging
   8import math
   9from typing import Any
  10
  11import torch
  12from captum.attr import (
  13    InternalInfluence,
  14    LayerConductance,
  15    LayerDeepLift,
  16    LayerDeepLiftShap,
  17    LayerFeatureAblation,
  18    LayerGradCam,
  19    LayerGradientShap,
  20    LayerGradientXActivation,
  21    LayerIntegratedGradients,
  22    LayerLRP,
  23)
  24from torch import Tensor
  25
  26from clarena.backbones import HATMaskBackbone
  27from clarena.cl_algorithms import AdaHAT
  28from clarena.heads import HeadDIL, HeadsTIL
  29from clarena.utils.metrics import HATNetworkCapacityMetric
  30from clarena.utils.transforms import min_max_normalize
  31
  32# always get logger for built-in logging in each module
  33pylogger = logging.getLogger(__name__)
  34
  35
  36class FGAdaHAT(AdaHAT):
  37    r"""FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm.
  38
  39    An architecture-based continual learning approach that improves [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
  40
  41    We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components.
  42    """
  43
  44    def __init__(
  45        self,
  46        backbone: HATMaskBackbone,
  47        heads: HeadsTIL | HeadDIL,
  48        adjustment_intensity: float,
  49        importance_type: str,
  50        importance_summing_strategy: str,
  51        importance_scheduler_type: str,
  52        neuron_to_weight_importance_aggregation_mode: str,
  53        s_max: float,
  54        clamp_threshold: float,
  55        mask_sparsity_reg_factor: float,
  56        mask_sparsity_reg_mode: str = "original",
  57        base_importance: float = 0.01,
  58        base_mask_sparsity_reg: float = 0.1,
  59        base_linear: float = 10,
  60        filter_by_cumulative_mask: bool = False,
  61        filter_unmasked_importance: bool = True,
  62        step_multiply_training_mask: bool = True,
  63        task_embedding_init_mode: str = "N01",
  64        importance_summing_strategy_linear_step: float | None = None,
  65        importance_summing_strategy_exponential_rate: float | None = None,
  66        importance_summing_strategy_log_base: float | None = None,
  67        non_algorithmic_hparams: dict[str, Any] = {},
  68        **kwargs,
  69    ) -> None:
  70        r"""Initialize the FG-AdaHAT algorithm with the network.
  71
  72        **Args:**
  73        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
  74        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. FG-AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
  75        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
  76        - **importance_type** (`str`): the type of neuron-wise importance, must be one of:
  77            1. 'input_weight_abs_sum': sum of absolute input weights;
  78            2. 'output_weight_abs_sum': sum of absolute output weights;
  79            3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
  80            4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
  81            5. 'activation_abs': absolute activation;
  82            6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
  83            7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
  84            8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
  85            9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
  86            10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
  87            11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
  88            12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
  89            13. 'conductance_abs': absolute layer conductance;
  90            14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
  91            15. 'gradcam_abs': absolute Grad-CAM;
  92            16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
  93            17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
  94            18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
  95            19. 'integrated_gradients_abs': absolute Integrated Gradients;
  96            20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
  97            21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
  98            22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
  99            23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
 100        - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of:
 101            1. 'add_latest': add the latest neuron-wise importance to the summative importance;
 102            2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
 103            3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
 104            4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
 105            5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID;
 106            6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID;
 107            7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
 108            8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
 109            9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID;
 110        - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of:
 111            1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization betwwen the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`);
 112            2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
 113            3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
 114        - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
 115            1. 'min': take the minimum of neuron-wise importance for each weight;
 116            2. 'max': take the maximum of neuron-wise importance for each weight;
 117            3. 'mean': take the mean of neuron-wise importance for each weight.
 118        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 119        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 120        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 121        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 122            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 123            2. 'cross': the cross version of mask sparsity regularization.
 124        - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01.
 125        - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
 126        - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
 127        - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False.
 128        - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set to 0) at the end of task training. Default: False.
 129        - **step_multiply_training_mask** (`bool`): whether to multiply the training mask to the importance at each training step. Default: True.
 130        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 131            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 132            2. 'U-11': uniform distribution $U(-1, 1)$.
 133            3. 'U01': uniform distribution $U(0, 1)$.
 134            4. 'U-10': uniform distribution $U(-1, 0)$.
 135            5. 'last': inherit the task embedding from the last task.
 136        - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0.
 137        - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1.
 138        - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1.
 139        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 140        - **kwargs**: Reserved for multiple inheritance.
 141
 142        """
 143        super().__init__(
 144            backbone=backbone,
 145            heads=heads,
 146            adjustment_mode=None,  # use the own adjustment mechanism of FG-AdaHAT
 147            adjustment_intensity=adjustment_intensity,
 148            s_max=s_max,
 149            clamp_threshold=clamp_threshold,
 150            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
 151            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
 152            task_embedding_init_mode=task_embedding_init_mode,
 153            epsilon=base_mask_sparsity_reg,  # the epsilon is now the base mask sparsity regularization factor
 154            non_algorithmic_hparams=non_algorithmic_hparams,
 155            **kwargs,
 156        )
 157
 158        # save additional algorithmic hyperparameters
 159        self.save_hyperparameters(
 160            "adjustment_intensity",
 161            "importance_type",
 162            "importance_summing_strategy",
 163            "importance_scheduler_type",
 164            "neuron_to_weight_importance_aggregation_mode",
 165            "s_max",
 166            "clamp_threshold",
 167            "mask_sparsity_reg_factor",
 168            "mask_sparsity_reg_mode",
 169            "base_importance",
 170            "base_mask_sparsity_reg",
 171            "base_linear",
 172            "filter_by_cumulative_mask",
 173            "filter_unmasked_importance",
 174            "step_multiply_training_mask",
 175        )
 176
 177        self.importance_type: str | None = importance_type
 178        r"""The type of the neuron-wise importance added to AdaHAT importance."""
 179
 180        self.importance_scheduler_type: str = importance_scheduler_type
 181        r"""The type of the importance scheduler."""
 182        self.neuron_to_weight_importance_aggregation_mode: str = (
 183            neuron_to_weight_importance_aggregation_mode
 184        )
 185        r"""The mode of aggregation from neuron-wise to weight-wise importance. """
 186        self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask
 187        r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate."""
 188        self.filter_unmasked_importance: bool = filter_unmasked_importance
 189        r"""The flag to filter unmasked importance values (set them to 0) at the end of task training."""
 190        self.step_multiply_training_mask: bool = step_multiply_training_mask
 191        r"""The flag to multiply the training mask to the importance at each training step."""
 192
 193        # importance summing strategy
 194        self.importance_summing_strategy: str = importance_summing_strategy
 195        r"""The strategy to sum the neuron-wise importance for previous tasks."""
 196        if importance_summing_strategy_linear_step is not None:
 197            self.importance_summing_strategy_linear_step: float = (
 198                importance_summing_strategy_linear_step
 199            )
 200            r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease')."""
 201        if importance_summing_strategy_exponential_rate is not None:
 202            self.importance_summing_strategy_exponential_rate: float = (
 203                importance_summing_strategy_exponential_rate
 204            )
 205            r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """
 206        if importance_summing_strategy_log_base is not None:
 207            self.importance_summing_strategy_log_base: float = (
 208                importance_summing_strategy_log_base
 209            )
 210            r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """
 211
 212        # base values
 213        self.base_importance: float = base_importance
 214        r"""The base value added to the importance to avoid zero. """
 215        self.base_mask_sparsity_reg: float = base_mask_sparsity_reg
 216        r"""The base value added to the mask sparsity regularization to avoid zero. """
 217        self.base_linear: float = base_linear
 218        r"""The base value added to the linear layer to avoid zero. """
 219
 220        self.importances: dict[int, dict[str, Tensor]] = {}
 221        r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """
 222        self.summative_importance_for_previous_tasks: dict[str, Tensor] = {}
 223        r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """
 224
 225        self.num_steps_t: int
 226        r"""The number of training steps for the current task `self.task_id`."""
 227        # set manual optimization
 228        self.automatic_optimization = False
 229
 230        FGAdaHAT.sanity_check(self)
 231
 232    def sanity_check(self) -> None:
 233        r"""Sanity check."""
 234
 235        # check importance type
 236        if self.importance_type not in [
 237            "input_weight_abs_sum",
 238            "output_weight_abs_sum",
 239            "input_weight_gradient_abs_sum",
 240            "output_weight_gradient_abs_sum",
 241            "activation_abs",
 242            "input_weight_abs_sum_x_activation_abs",
 243            "output_weight_abs_sum_x_activation_abs",
 244            "gradient_x_activation_abs",
 245            "input_weight_gradient_square_sum",
 246            "output_weight_gradient_square_sum",
 247            "input_weight_gradient_square_sum_x_activation_abs",
 248            "output_weight_gradient_square_sum_x_activation_abs",
 249            "conductance_abs",
 250            "internal_influence_abs",
 251            "gradcam_abs",
 252            "deeplift_abs",
 253            "deepliftshap_abs",
 254            "gradientshap_abs",
 255            "integrated_gradients_abs",
 256            "feature_ablation_abs",
 257            "lrp_abs",
 258            "cbp_adaptation",
 259            "cbp_adaptive_contribution",
 260        ]:
 261            raise ValueError(
 262                f"importance_type must be one of the predefined types, but got {self.importance_type}"
 263            )
 264
 265        # check importance summing strategy
 266        if self.importance_summing_strategy not in [
 267            "add_latest",
 268            "add_all",
 269            "add_average",
 270            "linear_decrease",
 271            "quadratic_decrease",
 272            "cubic_decrease",
 273            "exponential_decrease",
 274            "log_decrease",
 275            "factorial_decrease",
 276        ]:
 277            raise ValueError(
 278                f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}"
 279            )
 280
 281        # check importance scheduler type
 282        if self.importance_scheduler_type not in [
 283            "linear_sparsity_reg",
 284            "sparsity_reg",
 285            "summative_mask_sparsity_reg",
 286        ]:
 287            raise ValueError(
 288                f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}"
 289            )
 290
 291        # check neuron to weight importance aggregation mode
 292        if self.neuron_to_weight_importance_aggregation_mode not in [
 293            "min",
 294            "max",
 295            "mean",
 296        ]:
 297            raise ValueError(
 298                f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}"
 299            )
 300
 301        # check base values
 302        if self.base_importance < 0:
 303            raise ValueError(
 304                f"base_importance must be >= 0, but got {self.base_importance}"
 305            )
 306        if self.base_mask_sparsity_reg <= 0:
 307            raise ValueError(
 308                f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}"
 309            )
 310        if self.base_linear <= 0:
 311            raise ValueError(f"base_linear must be > 0, but got {self.base_linear}")
 312
 313    def on_train_start(self) -> None:
 314        r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization."""
 315        super().on_train_start()
 316
 317        self.importances[self.task_id] = (
 318            {}
 319        )  # initialize the importance for the current task
 320
 321        # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time.
 322        for layer_name in self.backbone.weighted_layer_names:
 323            layer = self.backbone.get_layer_by_name(
 324                layer_name
 325            )  # get the layer by its name
 326            num_units = layer.weight.shape[0]
 327
 328            # initialize the accumulated importance at the beginning of each task
 329            self.importances[self.task_id][layer_name] = torch.zeros(num_units).to(
 330                self.device
 331            )
 332
 333            # reset the number of steps counter for the current task
 334            self.num_steps_t = 0
 335
 336            # initialize the summative neuron-wise importance at the beginning of the first task
 337            if self.task_id == 1:
 338                self.summative_importance_for_previous_tasks[layer_name] = torch.zeros(
 339                    num_units
 340                ).to(
 341                    self.device
 342                )  # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$
 343
 344    def clip_grad_by_adjustment(
 345        self,
 346        network_sparsity: dict[str, Tensor],
 347    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
 348        r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper.
 349
 350        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
 351
 352        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 353
 354        **Args:**
 355        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.
 356
 357        **Returns:**
 358        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
 359        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
 360        - **capacity** (`Tensor`): the calculated network capacity.
 361        """
 362
 363        # initialize network capacity metric
 364        capacity = HATNetworkCapacityMetric().to(self.device)
 365        adjustment_rate_weight = {}
 366        adjustment_rate_bias = {}
 367
 368        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper
 369        for layer_name in self.backbone.weighted_layer_names:
 370
 371            layer = self.backbone.get_layer_by_name(
 372                layer_name
 373            )  # get the layer by its name
 374
 375            # placeholder for the adjustment rate to avoid the error of using it before assignment
 376            adjustment_rate_weight_layer = 1
 377            adjustment_rate_bias_layer = 1
 378
 379            # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, added the base value, and filtered by the mask
 380            weight_importance, bias_importance = (
 381                self.backbone.get_layer_measure_parameter_wise(
 382                    neuron_wise_measure=self.summative_importance_for_previous_tasks,
 383                    layer_name=layer_name,
 384                    aggregation_mode=self.neuron_to_weight_importance_aggregation_mode,
 385                )
 386            )
 387
 388            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
 389                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
 390                layer_name=layer_name,
 391                aggregation_mode="min",
 392            )
 393
 394            # filter the weight importance by the cumulative mask
 395            if self.filter_by_cumulative_mask:
 396                weight_importance = weight_importance * weight_mask
 397                bias_importance = bias_importance * bias_mask
 398
 399            network_sparsity_layer = network_sparsity[layer_name]
 400
 401            # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper
 402            factor = network_sparsity_layer + self.base_mask_sparsity_reg
 403            if self.importance_scheduler_type == "linear_sparsity_reg":
 404                factor = factor * (self.task_id + self.base_linear)
 405            elif self.importance_scheduler_type == "sparsity_reg":
 406                pass
 407            elif self.importance_scheduler_type == "summative_mask_sparsity_reg":
 408                factor = factor * (
 409                    self.summative_mask_for_previous_tasks + self.base_linear
 410                )
 411
 412            # calculate the adjustment rate
 413            adjustment_rate_weight_layer = torch.div(
 414                self.adjustment_intensity,
 415                (factor * weight_importance + self.adjustment_intensity),
 416            )
 417
 418            adjustment_rate_bias_layer = torch.div(
 419                self.adjustment_intensity,
 420                (factor * bias_importance + self.adjustment_intensity),
 421            )
 422
 423            # apply the adjustment rate to the gradients
 424            layer.weight.grad.data *= adjustment_rate_weight_layer
 425            if layer.bias is not None:
 426                layer.bias.grad.data *= adjustment_rate_bias_layer
 427
 428            # store the adjustment rate for logging
 429            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
 430            if layer.bias is not None:
 431                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
 432
 433            # update network capacity metric
 434            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
 435
 436        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
 437
 438    def on_train_batch_end(
 439        self, outputs: dict[str, Any], batch: Any, batch_idx: int
 440    ) -> None:
 441        r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.
 442
 443        **Args:**
 444        - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`).
 445        - **batch** (`Any`): training data batch.
 446        - **batch_idx** (`int`): index of the current batch (for mask figure file name).
 447        """
 448
 449        # get potential useful information from training batch
 450        activations = outputs["activations"]
 451        input = outputs["input"]
 452        target = outputs["target"]
 453        mask = outputs["mask"]
 454        num_batches = self.trainer.num_training_batches
 455
 456        for layer_name in self.backbone.weighted_layer_names:
 457            # layer-wise operation
 458
 459            activation = activations[layer_name]
 460
 461            # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper.
 462            if self.importance_type == "input_weight_abs_sum":
 463                importance_step = self.get_importance_step_layer_weight_abs_sum(
 464                    layer_name=layer_name,
 465                    if_output_weight=False,
 466                    reciprocal=False,
 467                )
 468            elif self.importance_type == "output_weight_abs_sum":
 469                importance_step = self.get_importance_step_layer_weight_abs_sum(
 470                    layer_name=layer_name,
 471                    if_output_weight=True,
 472                    reciprocal=False,
 473                )
 474            elif self.importance_type == "input_weight_gradient_abs_sum":
 475                importance_step = (
 476                    self.get_importance_step_layer_weight_gradient_abs_sum(
 477                        layer_name=layer_name, if_output_weight=False
 478                    )
 479                )
 480            elif self.importance_type == "output_weight_gradient_abs_sum":
 481                importance_step = (
 482                    self.get_importance_step_layer_weight_gradient_abs_sum(
 483                        layer_name=layer_name, if_output_weight=True
 484                    )
 485                )
 486            elif self.importance_type == "activation_abs":
 487                importance_step = self.get_importance_step_layer_activation_abs(
 488                    activation=activation
 489                )
 490            elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
 491                importance_step = (
 492                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
 493                        layer_name=layer_name,
 494                        activation=activation,
 495                        if_output_weight=False,
 496                    )
 497                )
 498            elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
 499                importance_step = (
 500                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
 501                        layer_name=layer_name,
 502                        activation=activation,
 503                        if_output_weight=True,
 504                    )
 505                )
 506            elif self.importance_type == "gradient_x_activation_abs":
 507                importance_step = (
 508                    self.get_importance_step_layer_gradient_x_activation_abs(
 509                        layer_name=layer_name,
 510                        input=input,
 511                        target=target,
 512                        batch_idx=batch_idx,
 513                        num_batches=num_batches,
 514                    )
 515                )
 516            elif self.importance_type == "input_weight_gradient_square_sum":
 517                importance_step = (
 518                    self.get_importance_step_layer_weight_gradient_square_sum(
 519                        layer_name=layer_name,
 520                        activation=activation,
 521                        if_output_weight=False,
 522                    )
 523                )
 524            elif self.importance_type == "output_weight_gradient_square_sum":
 525                importance_step = (
 526                    self.get_importance_step_layer_weight_gradient_square_sum(
 527                        layer_name=layer_name,
 528                        activation=activation,
 529                        if_output_weight=True,
 530                    )
 531                )
 532            elif (
 533                self.importance_type
 534                == "input_weight_gradient_square_sum_x_activation_abs"
 535            ):
 536                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 537                    layer_name=layer_name,
 538                    activation=activation,
 539                    if_output_weight=False,
 540                )
 541            elif (
 542                self.importance_type
 543                == "output_weight_gradient_square_sum_x_activation_abs"
 544            ):
 545                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 546                    layer_name=layer_name,
 547                    activation=activation,
 548                    if_output_weight=True,
 549                )
 550            elif self.importance_type == "conductance_abs":
 551                importance_step = self.get_importance_step_layer_conductance_abs(
 552                    layer_name=layer_name,
 553                    input=input,
 554                    baselines=None,
 555                    target=target,
 556                    batch_idx=batch_idx,
 557                    num_batches=num_batches,
 558                )
 559            elif self.importance_type == "internal_influence_abs":
 560                importance_step = self.get_importance_step_layer_internal_influence_abs(
 561                    layer_name=layer_name,
 562                    input=input,
 563                    baselines=None,
 564                    target=target,
 565                    batch_idx=batch_idx,
 566                    num_batches=num_batches,
 567                )
 568            elif self.importance_type == "gradcam_abs":
 569                importance_step = self.get_importance_step_layer_gradcam_abs(
 570                    layer_name=layer_name,
 571                    input=input,
 572                    target=target,
 573                    batch_idx=batch_idx,
 574                    num_batches=num_batches,
 575                )
 576            elif self.importance_type == "deeplift_abs":
 577                importance_step = self.get_importance_step_layer_deeplift_abs(
 578                    layer_name=layer_name,
 579                    input=input,
 580                    baselines=None,
 581                    target=target,
 582                    batch_idx=batch_idx,
 583                    num_batches=num_batches,
 584                )
 585            elif self.importance_type == "deepliftshap_abs":
 586                importance_step = self.get_importance_step_layer_deepliftshap_abs(
 587                    layer_name=layer_name,
 588                    input=input,
 589                    baselines=None,
 590                    target=target,
 591                    batch_idx=batch_idx,
 592                    num_batches=num_batches,
 593                )
 594            elif self.importance_type == "gradientshap_abs":
 595                importance_step = self.get_importance_step_layer_gradientshap_abs(
 596                    layer_name=layer_name,
 597                    input=input,
 598                    baselines=None,
 599                    target=target,
 600                    batch_idx=batch_idx,
 601                    num_batches=num_batches,
 602                )
 603            elif self.importance_type == "integrated_gradients_abs":
 604                importance_step = (
 605                    self.get_importance_step_layer_integrated_gradients_abs(
 606                        layer_name=layer_name,
 607                        input=input,
 608                        baselines=None,
 609                        target=target,
 610                        batch_idx=batch_idx,
 611                        num_batches=num_batches,
 612                    )
 613                )
 614            elif self.importance_type == "feature_ablation_abs":
 615                importance_step = self.get_importance_step_layer_feature_ablation_abs(
 616                    layer_name=layer_name,
 617                    input=input,
 618                    layer_baselines=None,
 619                    target=target,
 620                    batch_idx=batch_idx,
 621                    num_batches=num_batches,
 622                )
 623            elif self.importance_type == "lrp_abs":
 624                importance_step = self.get_importance_step_layer_lrp_abs(
 625                    layer_name=layer_name,
 626                    input=input,
 627                    target=target,
 628                    batch_idx=batch_idx,
 629                    num_batches=num_batches,
 630                )
 631            elif self.importance_type == "cbp_adaptation":
 632                importance_step = self.get_importance_step_layer_weight_abs_sum(
 633                    layer_name=layer_name,
 634                    if_output_weight=False,
 635                    reciprocal=True,
 636                )
 637            elif self.importance_type == "cbp_adaptive_contribution":
 638                importance_step = (
 639                    self.get_importance_step_layer_cbp_adaptive_contribution(
 640                        layer_name=layer_name,
 641                        activation=activation,
 642                    )
 643                )
 644
 645            importance_step = min_max_normalize(
 646                importance_step
 647            )  # min-max scaling the utility to $[0, 1]$. See Eq. (5) in the paper
 648
 649            # multiply the importance by the training mask. See Eq. (6) in the paper
 650            if self.step_multiply_training_mask:
 651                importance_step = importance_step * mask[layer_name]
 652
 653            # update accumulated importance
 654            self.importances[self.task_id][layer_name] = (
 655                self.importances[self.task_id][layer_name] + importance_step
 656            )
 657
 658        # update number of steps counter
 659        self.num_steps_t += 1
 660
 661    def on_train_end(self) -> None:
 662        r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task."""
 663        super().on_train_end()  # store the mask and update cumulative and summative masks
 664
 665        for layer_name in self.backbone.weighted_layer_names:
 666
 667            # average the neuron-wise step importance. See Eq. (4) in the paper
 668            self.importances[self.task_id][layer_name] = (
 669                self.importances[self.task_id][layer_name]
 670            ) / self.num_steps_t
 671
 672            # add the base importance. See Eq. (6) in the paper
 673            self.importances[self.task_id][layer_name] = (
 674                self.importances[self.task_id][layer_name] + self.base_importance
 675            )
 676
 677            # filter unmasked importance
 678            if self.filter_unmasked_importance:
 679                self.importances[self.task_id][layer_name] = (
 680                    self.importances[self.task_id][layer_name]
 681                    * self.backbone.masks[f"{self.task_id}"][layer_name]
 682                )
 683
 684            # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
 685            if self.importance_summing_strategy == "add_latest":
 686                self.summative_importance_for_previous_tasks[
 687                    layer_name
 688                ] += self.importances[self.task_id][layer_name]
 689
 690            elif self.importance_summing_strategy == "add_all":
 691                for t in range(1, self.task_id + 1):
 692                    self.summative_importance_for_previous_tasks[
 693                        layer_name
 694                    ] += self.importances[t][layer_name]
 695
 696            elif self.importance_summing_strategy == "add_average":
 697                for t in range(1, self.task_id + 1):
 698                    self.summative_importance_for_previous_tasks[layer_name] += (
 699                        self.importances[t][layer_name] / self.task_id
 700                    )
 701            else:
 702                self.summative_importance_for_previous_tasks[
 703                    layer_name
 704                ] = torch.zeros_like(
 705                    self.summative_importance_for_previous_tasks[layer_name]
 706                ).to(
 707                    self.device
 708                )  # starting adding from 0
 709
 710                if self.importance_summing_strategy == "linear_decrease":
 711                    s = self.importance_summing_strategy_linear_step
 712                    for t in range(1, self.task_id + 1):
 713                        w_t = s * (self.task_id - t) + 1
 714
 715                elif self.importance_summing_strategy == "quadratic_decrease":
 716                    for t in range(1, self.task_id + 1):
 717                        w_t = (self.task_id - t + 1) ** 2
 718                elif self.importance_summing_strategy == "cubic_decrease":
 719                    for t in range(1, self.task_id + 1):
 720                        w_t = (self.task_id - t + 1) ** 3
 721                elif self.importance_summing_strategy == "exponential_decrease":
 722                    for t in range(1, self.task_id + 1):
 723                        r = self.importance_summing_strategy_exponential_rate
 724
 725                        w_t = r ** (self.task_id - t + 1)
 726                elif self.importance_summing_strategy == "log_decrease":
 727                    a = self.importance_summing_strategy_log_base
 728                    for t in range(1, self.task_id + 1):
 729                        w_t = math.log(self.task_id - t, a) + 1
 730                elif self.importance_summing_strategy == "factorial_decrease":
 731                    for t in range(1, self.task_id + 1):
 732                        w_t = math.factorial(self.task_id - t + 1)
 733                else:
 734                    raise ValueError
 735                self.summative_importance_for_previous_tasks[layer_name] += (
 736                    self.importances[t][layer_name] * w_t
 737                )
 738
 739    def get_importance_step_layer_weight_abs_sum(
 740        self: str,
 741        layer_name: str,
 742        if_output_weight: bool,
 743        reciprocal: bool,
 744    ) -> Tensor:
 745        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.
 746
 747        **Args:**
 748        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 749        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 750        - **reciprocal** (`bool`): whether to take reciprocal.
 751
 752        **Returns:**
 753        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 754        """
 755        layer = self.backbone.get_layer_by_name(layer_name)
 756
 757        if not if_output_weight:
 758            weight_abs = torch.abs(layer.weight.data)
 759            weight_abs_sum = torch.sum(
 760                weight_abs,
 761                dim=[
 762                    i for i in range(weight_abs.dim()) if i != 0
 763                ],  # sum over the input dimension
 764            )
 765        else:
 766            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
 767            weight_abs_sum = torch.sum(
 768                weight_abs,
 769                dim=[
 770                    i for i in range(weight_abs.dim()) if i != 1
 771                ],  # sum over the output dimension
 772            )
 773
 774        if reciprocal:
 775            weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum)
 776            importance_step_layer = weight_abs_sum_reciprocal
 777        else:
 778            importance_step_layer = weight_abs_sum
 779        importance_step_layer = importance_step_layer.detach()
 780
 781        return importance_step_layer
 782
 783    def get_importance_step_layer_weight_gradient_abs_sum(
 784        self: str,
 785        layer_name: str,
 786        if_output_weight: bool,
 787    ) -> Tensor:
 788        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.
 789
 790        **Args:**
 791        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 792        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 793
 794        **Returns:**
 795        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 796        """
 797        layer = self.backbone.get_layer_by_name(layer_name)
 798
 799        if not if_output_weight:
 800            gradient_abs = torch.abs(layer.weight.grad.data)
 801            gradient_abs_sum = torch.sum(
 802                gradient_abs,
 803                dim=[
 804                    i for i in range(gradient_abs.dim()) if i != 0
 805                ],  # sum over the input dimension
 806            )
 807        else:
 808            gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data)
 809            gradient_abs_sum = torch.sum(
 810                gradient_abs,
 811                dim=[
 812                    i for i in range(gradient_abs.dim()) if i != 1
 813                ],  # sum over the output dimension
 814            )
 815
 816        importance_step_layer = gradient_abs_sum
 817        importance_step_layer = importance_step_layer.detach()
 818
 819        return importance_step_layer
 820
 821    def get_importance_step_layer_activation_abs(
 822        self: str,
 823        activation: Tensor,
 824    ) -> Tensor:
 825        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum.
 826
 827        **Args:**
 828        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 829
 830        **Returns:**
 831        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 832        """
 833        activation_abs_batch_mean = torch.mean(
 834            torch.abs(activation),
 835            dim=[
 836                i for i in range(activation.dim()) if i != 1
 837            ],  # average the features over batch samples
 838        )
 839        importance_step_layer = activation_abs_batch_mean
 840        importance_step_layer = importance_step_layer.detach()
 841
 842        return importance_step_layer
 843
 844    def get_importance_step_layer_weight_abs_sum_x_activation_abs(
 845        self: str,
 846        layer_name: str,
 847        activation: Tensor,
 848        if_output_weight: bool,
 849    ) -> Tensor:
 850        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
 851
 852        **Args:**
 853        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 854        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 855        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 856
 857        **Returns:**
 858        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 859        """
 860        layer = self.backbone.get_layer_by_name(layer_name)
 861
 862        if not if_output_weight:
 863            weight_abs = torch.abs(layer.weight.data)
 864            weight_abs_sum = torch.sum(
 865                weight_abs,
 866                dim=[
 867                    i for i in range(weight_abs.dim()) if i != 0
 868                ],  # sum over the input dimension
 869            )
 870        else:
 871            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
 872            weight_abs_sum = torch.sum(
 873                weight_abs,
 874                dim=[
 875                    i for i in range(weight_abs.dim()) if i != 1
 876                ],  # sum over the output dimension
 877            )
 878
 879        activation_abs_batch_mean = torch.mean(
 880            torch.abs(activation),
 881            dim=[
 882                i for i in range(activation.dim()) if i != 1
 883            ],  # average the features over batch samples
 884        )
 885
 886        importance_step_layer = weight_abs_sum * activation_abs_batch_mean
 887        importance_step_layer = importance_step_layer.detach()
 888
 889        return importance_step_layer
 890
 891    def get_importance_step_layer_gradient_x_activation_abs(
 892        self: str,
 893        layer_name: str,
 894        input: Tensor | tuple[Tensor, ...],
 895        target: Tensor | None,
 896        batch_idx: int,
 897        num_batches: int,
 898    ) -> Tensor:
 899        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum.
 900
 901        **Args:**
 902        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 903        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
 904        - **target** (`Tensor` | `None`): the target batch of the training step.
 905        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
 906        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
 907
 908        **Returns:**
 909        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 910        """
 911        layer = self.backbone.get_layer_by_name(layer_name)
 912
 913        input = input.requires_grad_()
 914
 915        # initialize the Layer Gradient X Activation object
 916        layer_gradient_x_activation = LayerGradientXActivation(
 917            forward_func=self.forward, layer=layer
 918        )
 919
 920        self.set_forward_func_return_logits_only(True)
 921        # calculate layer attribution of the step
 922        attribution = layer_gradient_x_activation.attribute(
 923            inputs=input,
 924            target=target,
 925            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
 926        )
 927        self.set_forward_func_return_logits_only(False)
 928
 929        attribution_abs_batch_mean = torch.mean(
 930            torch.abs(attribution),
 931            dim=[
 932                i for i in range(attribution.dim()) if i != 1
 933            ],  # average the features over batch samples
 934        )
 935
 936        importance_step_layer = attribution_abs_batch_mean
 937        importance_step_layer = importance_step_layer.detach()
 938
 939        return importance_step_layer
 940
 941    def get_importance_step_layer_weight_gradient_square_sum(
 942        self: str,
 943        layer_name: str,
 944        activation: Tensor,
 945        if_output_weight: bool,
 946    ) -> Tensor:
 947        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 948
 949        **Args:**
 950        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 951        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 952        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 953
 954        **Returns:**
 955        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 956        """
 957        layer = self.backbone.get_layer_by_name(layer_name)
 958
 959        if not if_output_weight:
 960            gradient_square = layer.weight.grad.data**2
 961            gradient_square_sum = torch.sum(
 962                gradient_square,
 963                dim=[
 964                    i for i in range(gradient_square.dim()) if i != 0
 965                ],  # sum over the input dimension
 966            )
 967        else:
 968            gradient_square = self.next_layer(layer_name).weight.grad.data**2
 969            gradient_square_sum = torch.sum(
 970                gradient_square,
 971                dim=[
 972                    i for i in range(gradient_square.dim()) if i != 1
 973                ],  # sum over the output dimension
 974            )
 975
 976        importance_step_layer = gradient_square_sum
 977        importance_step_layer = importance_step_layer.detach()
 978
 979        return importance_step_layer
 980
 981    def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 982        self: str,
 983        layer_name: str,
 984        activation: Tensor,
 985        if_output_weight: bool,
 986    ) -> Tensor:
 987        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 988
 989        **Args:**
 990        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 991        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 992        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 993
 994        **Returns:**
 995        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 996        """
 997        layer = self.backbone.get_layer_by_name(layer_name)
 998
 999        if not if_output_weight:
1000            gradient_square = layer.weight.grad.data**2
1001            gradient_square_sum = torch.sum(
1002                gradient_square,
1003                dim=[
1004                    i for i in range(gradient_square.dim()) if i != 0
1005                ],  # sum over the input dimension
1006            )
1007        else:
1008            gradient_square = self.next_layer(layer_name).weight.grad.data**2
1009            gradient_square_sum = torch.sum(
1010                gradient_square,
1011                dim=[
1012                    i for i in range(gradient_square.dim()) if i != 1
1013                ],  # sum over the output dimension
1014            )
1015
1016        activation_abs_batch_mean = torch.mean(
1017            torch.abs(activation),
1018            dim=[
1019                i for i in range(activation.dim()) if i != 1
1020            ],  # average the features over batch samples
1021        )
1022
1023        importance_step_layer = gradient_square_sum * activation_abs_batch_mean
1024        importance_step_layer = importance_step_layer.detach()
1025
1026        return importance_step_layer
1027
1028    def get_importance_step_layer_conductance_abs(
1029        self: str,
1030        layer_name: str,
1031        input: Tensor | tuple[Tensor, ...],
1032        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1033        target: Tensor | None,
1034        batch_idx: int,
1035        num_batches: int,
1036    ) -> Tensor:
1037        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.
1038
1039        **Args:**
1040        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1041        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1042        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
1043        - **target** (`Tensor` | `None`): the target batch of the training step.
1044        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1045        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.- **mask** (`Tensor`): the mask tensor of the layer. It has the same size as the feature tensor with size (number of units, ).
1046
1047        **Returns:**
1048        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1049        """
1050        layer = self.backbone.get_layer_by_name(layer_name)
1051
1052        # initialize the Layer Conductance object
1053        layer_conductance = LayerConductance(forward_func=self.forward, layer=layer)
1054
1055        self.set_forward_func_return_logits_only(True)
1056        # calculate layer attribution of the step
1057        attribution = layer_conductance.attribute(
1058            inputs=input,
1059            baselines=baselines,
1060            target=target,
1061            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1062        )
1063        self.set_forward_func_return_logits_only(False)
1064
1065        attribution_abs_batch_mean = torch.mean(
1066            torch.abs(attribution),
1067            dim=[
1068                i for i in range(attribution.dim()) if i != 1
1069            ],  # average the features over batch samples
1070        )
1071
1072        importance_step_layer = attribution_abs_batch_mean
1073        importance_step_layer = importance_step_layer.detach()
1074
1075        return importance_step_layer
1076
1077    def get_importance_step_layer_internal_influence_abs(
1078        self: str,
1079        layer_name: str,
1080        input: Tensor | tuple[Tensor, ...],
1081        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1082        target: Tensor | None,
1083        batch_idx: int,
1084        num_batches: int,
1085    ) -> Tensor:
1086        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum.
1087
1088        **Args:**
1089        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1090        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1091        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details.
1092        - **target** (`Tensor` | `None`): the target batch of the training step.
1093        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1094        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1095
1096        **Returns:**
1097        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1098        """
1099        layer = self.backbone.get_layer_by_name(layer_name)
1100
1101        # initialize the Internal Influence object
1102        internal_influence = InternalInfluence(forward_func=self.forward, layer=layer)
1103
1104        # convert the target to long type to avoid error
1105        target = target.long() if target is not None else None
1106
1107        self.set_forward_func_return_logits_only(True)
1108        # calculate layer attribution of the step
1109        attribution = internal_influence.attribute(
1110            inputs=input,
1111            baselines=baselines,
1112            target=target,
1113            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1114            n_steps=5,  # set 10 instead of default 50 to accelerate the computation
1115        )
1116        self.set_forward_func_return_logits_only(False)
1117
1118        attribution_abs_batch_mean = torch.mean(
1119            torch.abs(attribution),
1120            dim=[
1121                i for i in range(attribution.dim()) if i != 1
1122            ],  # average the features over batch samples
1123        )
1124
1125        importance_step_layer = attribution_abs_batch_mean
1126        importance_step_layer = importance_step_layer.detach()
1127
1128        return importance_step_layer
1129
1130    def get_importance_step_layer_gradcam_abs(
1131        self: str,
1132        layer_name: str,
1133        input: Tensor | tuple[Tensor, ...],
1134        target: Tensor | None,
1135        batch_idx: int,
1136        num_batches: int,
1137    ) -> Tensor:
1138        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum.
1139
1140        **Args:**
1141        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1142        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1143        - **target** (`Tensor` | `None`): the target batch of the training step.
1144        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1145        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1146
1147        **Returns:**
1148        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1149        """
1150        layer = self.backbone.get_layer_by_name(layer_name)
1151
1152        # initialize the GradCAM object
1153        gradcam = LayerGradCam(forward_func=self.forward, layer=layer)
1154
1155        self.set_forward_func_return_logits_only(True)
1156        # calculate layer attribution of the step
1157        attribution = gradcam.attribute(
1158            inputs=input,
1159            target=target,
1160            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1161        )
1162        self.set_forward_func_return_logits_only(False)
1163
1164        attribution_abs_batch_mean = torch.mean(
1165            torch.abs(attribution),
1166            dim=[
1167                i for i in range(attribution.dim()) if i != 1
1168            ],  # average the features over batch samples
1169        )
1170
1171        importance_step_layer = attribution_abs_batch_mean
1172        importance_step_layer = importance_step_layer.detach()
1173
1174        return importance_step_layer
1175
1176    def get_importance_step_layer_deeplift_abs(
1177        self: str,
1178        layer_name: str,
1179        input: Tensor | tuple[Tensor, ...],
1180        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1181        target: Tensor | None,
1182        batch_idx: int,
1183        num_batches: int,
1184    ) -> Tensor:
1185        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum.
1186
1187        **Args:**
1188        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1189        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1190        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details.
1191        - **target** (`Tensor` | `None`): the target batch of the training step.
1192        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1193        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1194
1195        **Returns:**
1196        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1197        """
1198        layer = self.backbone.get_layer_by_name(layer_name)
1199
1200        # initialize the Layer DeepLift object
1201        layer_deeplift = LayerDeepLift(model=self, layer=layer)
1202
1203        # convert the target to long type to avoid error
1204        target = target.long() if target is not None else None
1205
1206        self.set_forward_func_return_logits_only(True)
1207        # calculate layer attribution of the step
1208        attribution = layer_deeplift.attribute(
1209            inputs=input,
1210            baselines=baselines,
1211            target=target,
1212            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1213        )
1214        self.set_forward_func_return_logits_only(False)
1215
1216        attribution_abs_batch_mean = torch.mean(
1217            torch.abs(attribution),
1218            dim=[
1219                i for i in range(attribution.dim()) if i != 1
1220            ],  # average the features over batch samples
1221        )
1222
1223        importance_step_layer = attribution_abs_batch_mean
1224        importance_step_layer = importance_step_layer.detach()
1225
1226        return importance_step_layer
1227
1228    def get_importance_step_layer_deepliftshap_abs(
1229        self: str,
1230        layer_name: str,
1231        input: Tensor | tuple[Tensor, ...],
1232        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1233        target: Tensor | None,
1234        batch_idx: int,
1235        num_batches: int,
1236    ) -> Tensor:
1237        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.
1238
1239        **Args:**
1240        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1241        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1242        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
1243        - **target** (`Tensor` | `None`): the target batch of the training step.
1244        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1245        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1246
1247        **Returns:**
1248        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1249        """
1250        layer = self.backbone.get_layer_by_name(layer_name)
1251
1252        # initialize the Layer DeepLiftShap object
1253        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)
1254
1255        # convert the target to long type to avoid error
1256        target = target.long() if target is not None else None
1257
1258        self.set_forward_func_return_logits_only(True)
1259        # calculate layer attribution of the step
1260        attribution = layer_deepliftshap.attribute(
1261            inputs=input,
1262            baselines=baselines,
1263            target=target,
1264            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1265        )
1266        self.set_forward_func_return_logits_only(False)
1267
1268        attribution_abs_batch_mean = torch.mean(
1269            torch.abs(attribution),
1270            dim=[
1271                i for i in range(attribution.dim()) if i != 1
1272            ],  # average the features over batch samples
1273        )
1274
1275        importance_step_layer = attribution_abs_batch_mean
1276        importance_step_layer = importance_step_layer.detach()
1277
1278        return importance_step_layer
1279
1280    def get_importance_step_layer_gradientshap_abs(
1281        self: str,
1282        layer_name: str,
1283        input: Tensor | tuple[Tensor, ...],
1284        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1285        target: Tensor | None,
1286        batch_idx: int,
1287        num_batches: int,
1288    ) -> Tensor:
1289        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.
1290
1291        **Args:**
1292        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1293        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1294        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
1295        - **target** (`Tensor` | `None`): the target batch of the training step.
1296        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1297        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1298
1299        **Returns:**
1300        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1301        """
1302        layer = self.backbone.get_layer_by_name(layer_name)
1303
1304        if baselines is None:
1305            baselines = torch.zeros_like(
1306                input
1307            )  # baselines are mandatory for GradientShap API. We explicitly set them to zero
1308
1309        # initialize the Layer GradientShap object
1310        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)
1311
1312        # convert the target to long type to avoid error
1313        target = target.long() if target is not None else None
1314
1315        self.set_forward_func_return_logits_only(True)
1316        # calculate layer attribution of the step
1317        attribution = layer_gradientshap.attribute(
1318            inputs=input,
1319            baselines=baselines,
1320            target=target,
1321            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1322        )
1323        self.set_forward_func_return_logits_only(False)
1324
1325        attribution_abs_batch_mean = torch.mean(
1326            torch.abs(attribution),
1327            dim=[
1328                i for i in range(attribution.dim()) if i != 1
1329            ],  # average the features over batch samples
1330        )
1331
1332        importance_step_layer = attribution_abs_batch_mean
1333        importance_step_layer = importance_step_layer.detach()
1334
1335        return importance_step_layer
1336
1337    def get_importance_step_layer_integrated_gradients_abs(
1338        self: str,
1339        layer_name: str,
1340        input: Tensor | tuple[Tensor, ...],
1341        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1342        target: Tensor | None,
1343        batch_idx: int,
1344        num_batches: int,
1345    ) -> Tensor:
1346        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.
1347
1348        **Args:**
1349        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1350        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1351        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
1352        - **target** (`Tensor` | `None`): the target batch of the training step.
1353        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1354        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1355
1356        **Returns:**
1357        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1358        """
1359        layer = self.backbone.get_layer_by_name(layer_name)
1360
1361        # initialize the Layer Integrated Gradients object
1362        layer_integrated_gradients = LayerIntegratedGradients(
1363            forward_func=self.forward, layer=layer
1364        )
1365
1366        self.set_forward_func_return_logits_only(True)
1367        # calculate layer attribution of the step
1368        attribution = layer_integrated_gradients.attribute(
1369            inputs=input,
1370            baselines=baselines,
1371            target=target,
1372            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1373        )
1374        self.set_forward_func_return_logits_only(False)
1375
1376        attribution_abs_batch_mean = torch.mean(
1377            torch.abs(attribution),
1378            dim=[
1379                i for i in range(attribution.dim()) if i != 1
1380            ],  # average the features over batch samples
1381        )
1382
1383        importance_step_layer = attribution_abs_batch_mean
1384        importance_step_layer = importance_step_layer.detach()
1385
1386        return importance_step_layer
1387
1388    def get_importance_step_layer_feature_ablation_abs(
1389        self: str,
1390        layer_name: str,
1391        input: Tensor | tuple[Tensor, ...],
1392        layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1393        target: Tensor | None,
1394        batch_idx: int,
1395        num_batches: int,
1396        if_captum: bool = False,
1397    ) -> Tensor:
1398        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum.
1399
1400        **Args:**
1401        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1402        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1403        - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details.
1404        - **target** (`Tensor` | `None`): the target batch of the training step.
1405        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1406        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1407        - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster.
1408
1409        **Returns:**
1410        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1411        """
1412        layer = self.backbone.get_layer_by_name(layer_name)
1413
1414        if not if_captum:
1415            # 1. Baseline logits (take first element of forward output)
1416            baseline_out, _, _ = self.forward(
1417                input, "train", batch_idx, num_batches, self.task_id
1418            )
1419            if target is not None:
1420                baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1)
1421            else:
1422                baseline_scores = baseline_out.sum(dim=1)
1423
1424            # 2. Capture layer’s output shape
1425            activs = {}
1426            handle = layer.register_forward_hook(
1427                lambda module, inp, out: activs.setdefault("output", out.detach())
1428            )
1429            _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id)
1430            handle.remove()
1431            layer_output = activs["output"]  # shape (B, F, ...)
1432
1433            # 3. Build baseline tensor matching that shape
1434            if layer_baselines is None:
1435                baseline_tensor = torch.zeros_like(layer_output)
1436            elif isinstance(layer_baselines, (int, float)):
1437                baseline_tensor = torch.full_like(layer_output, layer_baselines)
1438            elif isinstance(layer_baselines, Tensor):
1439                if layer_baselines.shape == layer_output.shape:
1440                    baseline_tensor = layer_baselines
1441                elif layer_baselines.shape == layer_output.shape[1:]:
1442                    baseline_tensor = layer_baselines.unsqueeze(0).repeat(
1443                        layer_output.size(0), *([1] * layer_baselines.ndim)
1444                    )
1445                else:
1446                    raise ValueError(...)
1447            else:
1448                raise ValueError(...)
1449
1450            B, F = layer_output.size(0), layer_output.size(1)
1451
1452            # 4. Create a “mega-batch” replicating the input F times
1453            if isinstance(input, tuple):
1454                mega_inputs = tuple(
1455                    t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:])
1456                    for t in input
1457                )
1458            else:
1459                mega_inputs = (
1460                    input.unsqueeze(0)
1461                    .repeat(F, *([1] * input.ndim))
1462                    .view(-1, *input.shape[1:])
1463                )
1464
1465            # 5. Equally replicate the baseline tensor
1466            mega_baseline = (
1467                baseline_tensor.unsqueeze(0)
1468                .repeat(F, *([1] * baseline_tensor.ndim))
1469                .view(-1, *baseline_tensor.shape[1:])
1470            )
1471
1472            # 6. Precompute vectorized indices
1473            device = layer_output.device
1474            positions = torch.arange(F * B, device=device)  # [0,1,...,F*B-1]
1475            feat_idx = torch.arange(F, device=device).repeat_interleave(
1476                B
1477            )  # [0,0,...,1,1,...,F-1]
1478
1479            # 7. One hook to zero out each channel slice across the mega-batch
1480            def mega_ablate_hook(module, inp, out):
1481                out_mod = out.clone()
1482                # for each sample in mega-batch, zero its corresponding channel
1483                out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx]
1484                return out_mod
1485
1486            h = layer.register_forward_hook(mega_ablate_hook)
1487            out_all, _, _ = self.forward(
1488                mega_inputs, "train", batch_idx, num_batches, self.task_id
1489            )
1490            h.remove()
1491
1492            # 8. Recover scores, reshape [F*B] → [F, B], diff & mean
1493            if target is not None:
1494                tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1)
1495                scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1)
1496            else:
1497                scores_all = out_all.sum(dim=1)
1498
1499            scores_all = scores_all.view(F, B)
1500            diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all)
1501            importance_step_layer = diffs.mean(dim=1).detach()  # [F]
1502
1503            return importance_step_layer
1504
1505        else:
1506            # initialize the Layer Feature Ablation object
1507            layer_feature_ablation = LayerFeatureAblation(
1508                forward_func=self.forward, layer=layer
1509            )
1510
1511            # calculate layer attribution of the step
1512            self.set_forward_func_return_logits_only(True)
1513            attribution = layer_feature_ablation.attribute(
1514                inputs=input,
1515                layer_baselines=layer_baselines,
1516                # target=target, # disable target to enable perturbations_per_eval
1517                additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1518                perturbations_per_eval=128,  # to accelerate the computation
1519            )
1520            self.set_forward_func_return_logits_only(False)
1521
1522            attribution_abs_batch_mean = torch.mean(
1523                torch.abs(attribution),
1524                dim=[
1525                    i for i in range(attribution.dim()) if i != 1
1526                ],  # average the features over batch samples
1527            )
1528
1529        importance_step_layer = attribution_abs_batch_mean
1530        importance_step_layer = importance_step_layer.detach()
1531
1532        return importance_step_layer
1533
1534    def get_importance_step_layer_lrp_abs(
1535        self: str,
1536        layer_name: str,
1537        input: Tensor | tuple[Tensor, ...],
1538        target: Tensor | None,
1539        batch_idx: int,
1540        num_batches: int,
1541    ) -> Tensor:
1542        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum.
1543
1544        **Args:**
1545        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1546        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1547        - **target** (`Tensor` | `None`): the target batch of the training step.
1548        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1549        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1550
1551        **Returns:**
1552        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1553        """
1554        layer = self.backbone.get_layer_by_name(layer_name)
1555
1556        # initialize the Layer LRP object
1557        layer_lrp = LayerLRP(model=self, layer=layer)
1558
1559        # set model to evaluation mode to prevent updating the model parameters
1560        self.eval()
1561
1562        self.set_forward_func_return_logits_only(True)
1563        # calculate layer attribution of the step
1564        attribution = layer_lrp.attribute(
1565            inputs=input,
1566            target=target,
1567            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1568        )
1569        self.set_forward_func_return_logits_only(False)
1570
1571        attribution_abs_batch_mean = torch.mean(
1572            torch.abs(attribution),
1573            dim=[
1574                i for i in range(attribution.dim()) if i != 1
1575            ],  # average the features over batch samples
1576        )
1577
1578        importance_step_layer = attribution_abs_batch_mean
1579        importance_step_layer = importance_step_layer.detach()
1580
1581        return importance_step_layer
1582
1583    def get_importance_step_layer_cbp_adaptive_contribution(
1584        self: str,
1585        layer_name: str,
1586        activation: Tensor,
1587    ) -> Tensor:
1588        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the reciprocal of sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
1589
1590        **Args:**
1591        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1592        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
1593
1594        **Returns:**
1595        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1596        """
1597        layer = self.backbone.get_layer_by_name(layer_name)
1598
1599        input_weight_abs = torch.abs(layer.weight.data)
1600        input_weight_abs_sum = torch.sum(
1601            input_weight_abs,
1602            dim=[
1603                i for i in range(input_weight_abs.dim()) if i != 0
1604            ],  # sum over the input dimension
1605        )
1606        input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum)
1607
1608        output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
1609        output_weight_abs_sum = torch.sum(
1610            output_weight_abs,
1611            dim=[
1612                i for i in range(output_weight_abs.dim()) if i != 1
1613            ],  # sum over the output dimension
1614        )
1615
1616        activation_abs_batch_mean = torch.mean(
1617            torch.abs(activation),
1618            dim=[
1619                i for i in range(activation.dim()) if i != 1
1620            ],  # average the features over batch samples
1621        )
1622
1623        importance_step_layer = (
1624            output_weight_abs_sum
1625            * activation_abs_batch_mean
1626            * input_weight_abs_sum_reciprocal
1627        )
1628        importance_step_layer = importance_step_layer.detach()
1629
1630        return importance_step_layer
class FGAdaHAT(clarena.cl_algorithms.adahat.AdaHAT):
  37class FGAdaHAT(AdaHAT):
  38    r"""FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm.
  39
  40    An architecture-based continual learning approach that improves [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
  41
  42    We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components.
  43    """
  44
  45    def __init__(
  46        self,
  47        backbone: HATMaskBackbone,
  48        heads: HeadsTIL | HeadDIL,
  49        adjustment_intensity: float,
  50        importance_type: str,
  51        importance_summing_strategy: str,
  52        importance_scheduler_type: str,
  53        neuron_to_weight_importance_aggregation_mode: str,
  54        s_max: float,
  55        clamp_threshold: float,
  56        mask_sparsity_reg_factor: float,
  57        mask_sparsity_reg_mode: str = "original",
  58        base_importance: float = 0.01,
  59        base_mask_sparsity_reg: float = 0.1,
  60        base_linear: float = 10,
  61        filter_by_cumulative_mask: bool = False,
  62        filter_unmasked_importance: bool = True,
  63        step_multiply_training_mask: bool = True,
  64        task_embedding_init_mode: str = "N01",
  65        importance_summing_strategy_linear_step: float | None = None,
  66        importance_summing_strategy_exponential_rate: float | None = None,
  67        importance_summing_strategy_log_base: float | None = None,
  68        non_algorithmic_hparams: dict[str, Any] = {},
  69        **kwargs,
  70    ) -> None:
  71        r"""Initialize the FG-AdaHAT algorithm with the network.
  72
  73        **Args:**
  74        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
  75        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. FG-AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
  76        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
  77        - **importance_type** (`str`): the type of neuron-wise importance, must be one of:
  78            1. 'input_weight_abs_sum': sum of absolute input weights;
  79            2. 'output_weight_abs_sum': sum of absolute output weights;
  80            3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
  81            4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
  82            5. 'activation_abs': absolute activation;
  83            6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
  84            7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
  85            8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
  86            9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
  87            10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
  88            11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
  89            12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
  90            13. 'conductance_abs': absolute layer conductance;
  91            14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
  92            15. 'gradcam_abs': absolute Grad-CAM;
  93            16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
  94            17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
  95            18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
  96            19. 'integrated_gradients_abs': absolute Integrated Gradients;
  97            20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
  98            21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
  99            22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
 100            23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
 101        - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of:
 102            1. 'add_latest': add the latest neuron-wise importance to the summative importance;
 103            2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
 104            3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
 105            4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
 106            5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID;
 107            6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID;
 108            7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
 109            8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
 110            9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID;
 111        - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of:
 112            1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization betwwen the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`);
 113            2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
 114            3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
 115        - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
 116            1. 'min': take the minimum of neuron-wise importance for each weight;
 117            2. 'max': take the maximum of neuron-wise importance for each weight;
 118            3. 'mean': take the mean of neuron-wise importance for each weight.
 119        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 120        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 121        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 122        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 123            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 124            2. 'cross': the cross version of mask sparsity regularization.
 125        - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01.
 126        - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
 127        - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
 128        - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False.
 129        - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set to 0) at the end of task training. Default: False.
 130        - **step_multiply_training_mask** (`bool`): whether to multiply the training mask to the importance at each training step. Default: True.
 131        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 132            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 133            2. 'U-11': uniform distribution $U(-1, 1)$.
 134            3. 'U01': uniform distribution $U(0, 1)$.
 135            4. 'U-10': uniform distribution $U(-1, 0)$.
 136            5. 'last': inherit the task embedding from the last task.
 137        - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0.
 138        - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1.
 139        - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1.
 140        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 141        - **kwargs**: Reserved for multiple inheritance.
 142
 143        """
 144        super().__init__(
 145            backbone=backbone,
 146            heads=heads,
 147            adjustment_mode=None,  # use the own adjustment mechanism of FG-AdaHAT
 148            adjustment_intensity=adjustment_intensity,
 149            s_max=s_max,
 150            clamp_threshold=clamp_threshold,
 151            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
 152            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
 153            task_embedding_init_mode=task_embedding_init_mode,
 154            epsilon=base_mask_sparsity_reg,  # the epsilon is now the base mask sparsity regularization factor
 155            non_algorithmic_hparams=non_algorithmic_hparams,
 156            **kwargs,
 157        )
 158
 159        # save additional algorithmic hyperparameters
 160        self.save_hyperparameters(
 161            "adjustment_intensity",
 162            "importance_type",
 163            "importance_summing_strategy",
 164            "importance_scheduler_type",
 165            "neuron_to_weight_importance_aggregation_mode",
 166            "s_max",
 167            "clamp_threshold",
 168            "mask_sparsity_reg_factor",
 169            "mask_sparsity_reg_mode",
 170            "base_importance",
 171            "base_mask_sparsity_reg",
 172            "base_linear",
 173            "filter_by_cumulative_mask",
 174            "filter_unmasked_importance",
 175            "step_multiply_training_mask",
 176        )
 177
 178        self.importance_type: str | None = importance_type
 179        r"""The type of the neuron-wise importance added to AdaHAT importance."""
 180
 181        self.importance_scheduler_type: str = importance_scheduler_type
 182        r"""The type of the importance scheduler."""
 183        self.neuron_to_weight_importance_aggregation_mode: str = (
 184            neuron_to_weight_importance_aggregation_mode
 185        )
 186        r"""The mode of aggregation from neuron-wise to weight-wise importance. """
 187        self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask
 188        r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate."""
 189        self.filter_unmasked_importance: bool = filter_unmasked_importance
 190        r"""The flag to filter unmasked importance values (set them to 0) at the end of task training."""
 191        self.step_multiply_training_mask: bool = step_multiply_training_mask
 192        r"""The flag to multiply the training mask to the importance at each training step."""
 193
 194        # importance summing strategy
 195        self.importance_summing_strategy: str = importance_summing_strategy
 196        r"""The strategy to sum the neuron-wise importance for previous tasks."""
 197        if importance_summing_strategy_linear_step is not None:
 198            self.importance_summing_strategy_linear_step: float = (
 199                importance_summing_strategy_linear_step
 200            )
 201            r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease')."""
 202        if importance_summing_strategy_exponential_rate is not None:
 203            self.importance_summing_strategy_exponential_rate: float = (
 204                importance_summing_strategy_exponential_rate
 205            )
 206            r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """
 207        if importance_summing_strategy_log_base is not None:
 208            self.importance_summing_strategy_log_base: float = (
 209                importance_summing_strategy_log_base
 210            )
 211            r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """
 212
 213        # base values
 214        self.base_importance: float = base_importance
 215        r"""The base value added to the importance to avoid zero. """
 216        self.base_mask_sparsity_reg: float = base_mask_sparsity_reg
 217        r"""The base value added to the mask sparsity regularization to avoid zero. """
 218        self.base_linear: float = base_linear
 219        r"""The base value added to the linear layer to avoid zero. """
 220
 221        self.importances: dict[int, dict[str, Tensor]] = {}
 222        r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """
 223        self.summative_importance_for_previous_tasks: dict[str, Tensor] = {}
 224        r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """
 225
 226        self.num_steps_t: int
 227        r"""The number of training steps for the current task `self.task_id`."""
 228        # set manual optimization
 229        self.automatic_optimization = False
 230
 231        FGAdaHAT.sanity_check(self)
 232
 233    def sanity_check(self) -> None:
 234        r"""Sanity check."""
 235
 236        # check importance type
 237        if self.importance_type not in [
 238            "input_weight_abs_sum",
 239            "output_weight_abs_sum",
 240            "input_weight_gradient_abs_sum",
 241            "output_weight_gradient_abs_sum",
 242            "activation_abs",
 243            "input_weight_abs_sum_x_activation_abs",
 244            "output_weight_abs_sum_x_activation_abs",
 245            "gradient_x_activation_abs",
 246            "input_weight_gradient_square_sum",
 247            "output_weight_gradient_square_sum",
 248            "input_weight_gradient_square_sum_x_activation_abs",
 249            "output_weight_gradient_square_sum_x_activation_abs",
 250            "conductance_abs",
 251            "internal_influence_abs",
 252            "gradcam_abs",
 253            "deeplift_abs",
 254            "deepliftshap_abs",
 255            "gradientshap_abs",
 256            "integrated_gradients_abs",
 257            "feature_ablation_abs",
 258            "lrp_abs",
 259            "cbp_adaptation",
 260            "cbp_adaptive_contribution",
 261        ]:
 262            raise ValueError(
 263                f"importance_type must be one of the predefined types, but got {self.importance_type}"
 264            )
 265
 266        # check importance summing strategy
 267        if self.importance_summing_strategy not in [
 268            "add_latest",
 269            "add_all",
 270            "add_average",
 271            "linear_decrease",
 272            "quadratic_decrease",
 273            "cubic_decrease",
 274            "exponential_decrease",
 275            "log_decrease",
 276            "factorial_decrease",
 277        ]:
 278            raise ValueError(
 279                f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}"
 280            )
 281
 282        # check importance scheduler type
 283        if self.importance_scheduler_type not in [
 284            "linear_sparsity_reg",
 285            "sparsity_reg",
 286            "summative_mask_sparsity_reg",
 287        ]:
 288            raise ValueError(
 289                f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}"
 290            )
 291
 292        # check neuron to weight importance aggregation mode
 293        if self.neuron_to_weight_importance_aggregation_mode not in [
 294            "min",
 295            "max",
 296            "mean",
 297        ]:
 298            raise ValueError(
 299                f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}"
 300            )
 301
 302        # check base values
 303        if self.base_importance < 0:
 304            raise ValueError(
 305                f"base_importance must be >= 0, but got {self.base_importance}"
 306            )
 307        if self.base_mask_sparsity_reg <= 0:
 308            raise ValueError(
 309                f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}"
 310            )
 311        if self.base_linear <= 0:
 312            raise ValueError(f"base_linear must be > 0, but got {self.base_linear}")
 313
 314    def on_train_start(self) -> None:
 315        r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization."""
 316        super().on_train_start()
 317
 318        self.importances[self.task_id] = (
 319            {}
 320        )  # initialize the importance for the current task
 321
 322        # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time.
 323        for layer_name in self.backbone.weighted_layer_names:
 324            layer = self.backbone.get_layer_by_name(
 325                layer_name
 326            )  # get the layer by its name
 327            num_units = layer.weight.shape[0]
 328
 329            # initialize the accumulated importance at the beginning of each task
 330            self.importances[self.task_id][layer_name] = torch.zeros(num_units).to(
 331                self.device
 332            )
 333
 334            # reset the number of steps counter for the current task
 335            self.num_steps_t = 0
 336
 337            # initialize the summative neuron-wise importance at the beginning of the first task
 338            if self.task_id == 1:
 339                self.summative_importance_for_previous_tasks[layer_name] = torch.zeros(
 340                    num_units
 341                ).to(
 342                    self.device
 343                )  # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$
 344
 345    def clip_grad_by_adjustment(
 346        self,
 347        network_sparsity: dict[str, Tensor],
 348    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
 349        r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper.
 350
 351        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
 352
 353        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 354
 355        **Args:**
 356        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.
 357
 358        **Returns:**
 359        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
 360        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
 361        - **capacity** (`Tensor`): the calculated network capacity.
 362        """
 363
 364        # initialize network capacity metric
 365        capacity = HATNetworkCapacityMetric().to(self.device)
 366        adjustment_rate_weight = {}
 367        adjustment_rate_bias = {}
 368
 369        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper
 370        for layer_name in self.backbone.weighted_layer_names:
 371
 372            layer = self.backbone.get_layer_by_name(
 373                layer_name
 374            )  # get the layer by its name
 375
 376            # placeholder for the adjustment rate to avoid the error of using it before assignment
 377            adjustment_rate_weight_layer = 1
 378            adjustment_rate_bias_layer = 1
 379
 380            # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, added the base value, and filtered by the mask
 381            weight_importance, bias_importance = (
 382                self.backbone.get_layer_measure_parameter_wise(
 383                    neuron_wise_measure=self.summative_importance_for_previous_tasks,
 384                    layer_name=layer_name,
 385                    aggregation_mode=self.neuron_to_weight_importance_aggregation_mode,
 386                )
 387            )
 388
 389            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
 390                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
 391                layer_name=layer_name,
 392                aggregation_mode="min",
 393            )
 394
 395            # filter the weight importance by the cumulative mask
 396            if self.filter_by_cumulative_mask:
 397                weight_importance = weight_importance * weight_mask
 398                bias_importance = bias_importance * bias_mask
 399
 400            network_sparsity_layer = network_sparsity[layer_name]
 401
 402            # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper
 403            factor = network_sparsity_layer + self.base_mask_sparsity_reg
 404            if self.importance_scheduler_type == "linear_sparsity_reg":
 405                factor = factor * (self.task_id + self.base_linear)
 406            elif self.importance_scheduler_type == "sparsity_reg":
 407                pass
 408            elif self.importance_scheduler_type == "summative_mask_sparsity_reg":
 409                factor = factor * (
 410                    self.summative_mask_for_previous_tasks + self.base_linear
 411                )
 412
 413            # calculate the adjustment rate
 414            adjustment_rate_weight_layer = torch.div(
 415                self.adjustment_intensity,
 416                (factor * weight_importance + self.adjustment_intensity),
 417            )
 418
 419            adjustment_rate_bias_layer = torch.div(
 420                self.adjustment_intensity,
 421                (factor * bias_importance + self.adjustment_intensity),
 422            )
 423
 424            # apply the adjustment rate to the gradients
 425            layer.weight.grad.data *= adjustment_rate_weight_layer
 426            if layer.bias is not None:
 427                layer.bias.grad.data *= adjustment_rate_bias_layer
 428
 429            # store the adjustment rate for logging
 430            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
 431            if layer.bias is not None:
 432                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
 433
 434            # update network capacity metric
 435            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
 436
 437        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
 438
 439    def on_train_batch_end(
 440        self, outputs: dict[str, Any], batch: Any, batch_idx: int
 441    ) -> None:
 442        r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.
 443
 444        **Args:**
 445        - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`).
 446        - **batch** (`Any`): training data batch.
 447        - **batch_idx** (`int`): index of the current batch (for mask figure file name).
 448        """
 449
 450        # get potential useful information from training batch
 451        activations = outputs["activations"]
 452        input = outputs["input"]
 453        target = outputs["target"]
 454        mask = outputs["mask"]
 455        num_batches = self.trainer.num_training_batches
 456
 457        for layer_name in self.backbone.weighted_layer_names:
 458            # layer-wise operation
 459
 460            activation = activations[layer_name]
 461
 462            # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper.
 463            if self.importance_type == "input_weight_abs_sum":
 464                importance_step = self.get_importance_step_layer_weight_abs_sum(
 465                    layer_name=layer_name,
 466                    if_output_weight=False,
 467                    reciprocal=False,
 468                )
 469            elif self.importance_type == "output_weight_abs_sum":
 470                importance_step = self.get_importance_step_layer_weight_abs_sum(
 471                    layer_name=layer_name,
 472                    if_output_weight=True,
 473                    reciprocal=False,
 474                )
 475            elif self.importance_type == "input_weight_gradient_abs_sum":
 476                importance_step = (
 477                    self.get_importance_step_layer_weight_gradient_abs_sum(
 478                        layer_name=layer_name, if_output_weight=False
 479                    )
 480                )
 481            elif self.importance_type == "output_weight_gradient_abs_sum":
 482                importance_step = (
 483                    self.get_importance_step_layer_weight_gradient_abs_sum(
 484                        layer_name=layer_name, if_output_weight=True
 485                    )
 486                )
 487            elif self.importance_type == "activation_abs":
 488                importance_step = self.get_importance_step_layer_activation_abs(
 489                    activation=activation
 490                )
 491            elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
 492                importance_step = (
 493                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
 494                        layer_name=layer_name,
 495                        activation=activation,
 496                        if_output_weight=False,
 497                    )
 498                )
 499            elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
 500                importance_step = (
 501                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
 502                        layer_name=layer_name,
 503                        activation=activation,
 504                        if_output_weight=True,
 505                    )
 506                )
 507            elif self.importance_type == "gradient_x_activation_abs":
 508                importance_step = (
 509                    self.get_importance_step_layer_gradient_x_activation_abs(
 510                        layer_name=layer_name,
 511                        input=input,
 512                        target=target,
 513                        batch_idx=batch_idx,
 514                        num_batches=num_batches,
 515                    )
 516                )
 517            elif self.importance_type == "input_weight_gradient_square_sum":
 518                importance_step = (
 519                    self.get_importance_step_layer_weight_gradient_square_sum(
 520                        layer_name=layer_name,
 521                        activation=activation,
 522                        if_output_weight=False,
 523                    )
 524                )
 525            elif self.importance_type == "output_weight_gradient_square_sum":
 526                importance_step = (
 527                    self.get_importance_step_layer_weight_gradient_square_sum(
 528                        layer_name=layer_name,
 529                        activation=activation,
 530                        if_output_weight=True,
 531                    )
 532                )
 533            elif (
 534                self.importance_type
 535                == "input_weight_gradient_square_sum_x_activation_abs"
 536            ):
 537                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 538                    layer_name=layer_name,
 539                    activation=activation,
 540                    if_output_weight=False,
 541                )
 542            elif (
 543                self.importance_type
 544                == "output_weight_gradient_square_sum_x_activation_abs"
 545            ):
 546                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 547                    layer_name=layer_name,
 548                    activation=activation,
 549                    if_output_weight=True,
 550                )
 551            elif self.importance_type == "conductance_abs":
 552                importance_step = self.get_importance_step_layer_conductance_abs(
 553                    layer_name=layer_name,
 554                    input=input,
 555                    baselines=None,
 556                    target=target,
 557                    batch_idx=batch_idx,
 558                    num_batches=num_batches,
 559                )
 560            elif self.importance_type == "internal_influence_abs":
 561                importance_step = self.get_importance_step_layer_internal_influence_abs(
 562                    layer_name=layer_name,
 563                    input=input,
 564                    baselines=None,
 565                    target=target,
 566                    batch_idx=batch_idx,
 567                    num_batches=num_batches,
 568                )
 569            elif self.importance_type == "gradcam_abs":
 570                importance_step = self.get_importance_step_layer_gradcam_abs(
 571                    layer_name=layer_name,
 572                    input=input,
 573                    target=target,
 574                    batch_idx=batch_idx,
 575                    num_batches=num_batches,
 576                )
 577            elif self.importance_type == "deeplift_abs":
 578                importance_step = self.get_importance_step_layer_deeplift_abs(
 579                    layer_name=layer_name,
 580                    input=input,
 581                    baselines=None,
 582                    target=target,
 583                    batch_idx=batch_idx,
 584                    num_batches=num_batches,
 585                )
 586            elif self.importance_type == "deepliftshap_abs":
 587                importance_step = self.get_importance_step_layer_deepliftshap_abs(
 588                    layer_name=layer_name,
 589                    input=input,
 590                    baselines=None,
 591                    target=target,
 592                    batch_idx=batch_idx,
 593                    num_batches=num_batches,
 594                )
 595            elif self.importance_type == "gradientshap_abs":
 596                importance_step = self.get_importance_step_layer_gradientshap_abs(
 597                    layer_name=layer_name,
 598                    input=input,
 599                    baselines=None,
 600                    target=target,
 601                    batch_idx=batch_idx,
 602                    num_batches=num_batches,
 603                )
 604            elif self.importance_type == "integrated_gradients_abs":
 605                importance_step = (
 606                    self.get_importance_step_layer_integrated_gradients_abs(
 607                        layer_name=layer_name,
 608                        input=input,
 609                        baselines=None,
 610                        target=target,
 611                        batch_idx=batch_idx,
 612                        num_batches=num_batches,
 613                    )
 614                )
 615            elif self.importance_type == "feature_ablation_abs":
 616                importance_step = self.get_importance_step_layer_feature_ablation_abs(
 617                    layer_name=layer_name,
 618                    input=input,
 619                    layer_baselines=None,
 620                    target=target,
 621                    batch_idx=batch_idx,
 622                    num_batches=num_batches,
 623                )
 624            elif self.importance_type == "lrp_abs":
 625                importance_step = self.get_importance_step_layer_lrp_abs(
 626                    layer_name=layer_name,
 627                    input=input,
 628                    target=target,
 629                    batch_idx=batch_idx,
 630                    num_batches=num_batches,
 631                )
 632            elif self.importance_type == "cbp_adaptation":
 633                importance_step = self.get_importance_step_layer_weight_abs_sum(
 634                    layer_name=layer_name,
 635                    if_output_weight=False,
 636                    reciprocal=True,
 637                )
 638            elif self.importance_type == "cbp_adaptive_contribution":
 639                importance_step = (
 640                    self.get_importance_step_layer_cbp_adaptive_contribution(
 641                        layer_name=layer_name,
 642                        activation=activation,
 643                    )
 644                )
 645
 646            importance_step = min_max_normalize(
 647                importance_step
 648            )  # min-max scaling the utility to $[0, 1]$. See Eq. (5) in the paper
 649
 650            # multiply the importance by the training mask. See Eq. (6) in the paper
 651            if self.step_multiply_training_mask:
 652                importance_step = importance_step * mask[layer_name]
 653
 654            # update accumulated importance
 655            self.importances[self.task_id][layer_name] = (
 656                self.importances[self.task_id][layer_name] + importance_step
 657            )
 658
 659        # update number of steps counter
 660        self.num_steps_t += 1
 661
 662    def on_train_end(self) -> None:
 663        r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task."""
 664        super().on_train_end()  # store the mask and update cumulative and summative masks
 665
 666        for layer_name in self.backbone.weighted_layer_names:
 667
 668            # average the neuron-wise step importance. See Eq. (4) in the paper
 669            self.importances[self.task_id][layer_name] = (
 670                self.importances[self.task_id][layer_name]
 671            ) / self.num_steps_t
 672
 673            # add the base importance. See Eq. (6) in the paper
 674            self.importances[self.task_id][layer_name] = (
 675                self.importances[self.task_id][layer_name] + self.base_importance
 676            )
 677
 678            # filter unmasked importance
 679            if self.filter_unmasked_importance:
 680                self.importances[self.task_id][layer_name] = (
 681                    self.importances[self.task_id][layer_name]
 682                    * self.backbone.masks[f"{self.task_id}"][layer_name]
 683                )
 684
 685            # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
 686            if self.importance_summing_strategy == "add_latest":
 687                self.summative_importance_for_previous_tasks[
 688                    layer_name
 689                ] += self.importances[self.task_id][layer_name]
 690
 691            elif self.importance_summing_strategy == "add_all":
 692                for t in range(1, self.task_id + 1):
 693                    self.summative_importance_for_previous_tasks[
 694                        layer_name
 695                    ] += self.importances[t][layer_name]
 696
 697            elif self.importance_summing_strategy == "add_average":
 698                for t in range(1, self.task_id + 1):
 699                    self.summative_importance_for_previous_tasks[layer_name] += (
 700                        self.importances[t][layer_name] / self.task_id
 701                    )
 702            else:
 703                self.summative_importance_for_previous_tasks[
 704                    layer_name
 705                ] = torch.zeros_like(
 706                    self.summative_importance_for_previous_tasks[layer_name]
 707                ).to(
 708                    self.device
 709                )  # starting adding from 0
 710
 711                if self.importance_summing_strategy == "linear_decrease":
 712                    s = self.importance_summing_strategy_linear_step
 713                    for t in range(1, self.task_id + 1):
 714                        w_t = s * (self.task_id - t) + 1
 715
 716                elif self.importance_summing_strategy == "quadratic_decrease":
 717                    for t in range(1, self.task_id + 1):
 718                        w_t = (self.task_id - t + 1) ** 2
 719                elif self.importance_summing_strategy == "cubic_decrease":
 720                    for t in range(1, self.task_id + 1):
 721                        w_t = (self.task_id - t + 1) ** 3
 722                elif self.importance_summing_strategy == "exponential_decrease":
 723                    for t in range(1, self.task_id + 1):
 724                        r = self.importance_summing_strategy_exponential_rate
 725
 726                        w_t = r ** (self.task_id - t + 1)
 727                elif self.importance_summing_strategy == "log_decrease":
 728                    a = self.importance_summing_strategy_log_base
 729                    for t in range(1, self.task_id + 1):
 730                        w_t = math.log(self.task_id - t, a) + 1
 731                elif self.importance_summing_strategy == "factorial_decrease":
 732                    for t in range(1, self.task_id + 1):
 733                        w_t = math.factorial(self.task_id - t + 1)
 734                else:
 735                    raise ValueError
 736                self.summative_importance_for_previous_tasks[layer_name] += (
 737                    self.importances[t][layer_name] * w_t
 738                )
 739
 740    def get_importance_step_layer_weight_abs_sum(
 741        self: str,
 742        layer_name: str,
 743        if_output_weight: bool,
 744        reciprocal: bool,
 745    ) -> Tensor:
 746        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.
 747
 748        **Args:**
 749        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 750        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 751        - **reciprocal** (`bool`): whether to take reciprocal.
 752
 753        **Returns:**
 754        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 755        """
 756        layer = self.backbone.get_layer_by_name(layer_name)
 757
 758        if not if_output_weight:
 759            weight_abs = torch.abs(layer.weight.data)
 760            weight_abs_sum = torch.sum(
 761                weight_abs,
 762                dim=[
 763                    i for i in range(weight_abs.dim()) if i != 0
 764                ],  # sum over the input dimension
 765            )
 766        else:
 767            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
 768            weight_abs_sum = torch.sum(
 769                weight_abs,
 770                dim=[
 771                    i for i in range(weight_abs.dim()) if i != 1
 772                ],  # sum over the output dimension
 773            )
 774
 775        if reciprocal:
 776            weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum)
 777            importance_step_layer = weight_abs_sum_reciprocal
 778        else:
 779            importance_step_layer = weight_abs_sum
 780        importance_step_layer = importance_step_layer.detach()
 781
 782        return importance_step_layer
 783
 784    def get_importance_step_layer_weight_gradient_abs_sum(
 785        self: str,
 786        layer_name: str,
 787        if_output_weight: bool,
 788    ) -> Tensor:
 789        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.
 790
 791        **Args:**
 792        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 793        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 794
 795        **Returns:**
 796        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 797        """
 798        layer = self.backbone.get_layer_by_name(layer_name)
 799
 800        if not if_output_weight:
 801            gradient_abs = torch.abs(layer.weight.grad.data)
 802            gradient_abs_sum = torch.sum(
 803                gradient_abs,
 804                dim=[
 805                    i for i in range(gradient_abs.dim()) if i != 0
 806                ],  # sum over the input dimension
 807            )
 808        else:
 809            gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data)
 810            gradient_abs_sum = torch.sum(
 811                gradient_abs,
 812                dim=[
 813                    i for i in range(gradient_abs.dim()) if i != 1
 814                ],  # sum over the output dimension
 815            )
 816
 817        importance_step_layer = gradient_abs_sum
 818        importance_step_layer = importance_step_layer.detach()
 819
 820        return importance_step_layer
 821
 822    def get_importance_step_layer_activation_abs(
 823        self: str,
 824        activation: Tensor,
 825    ) -> Tensor:
 826        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum.
 827
 828        **Args:**
 829        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 830
 831        **Returns:**
 832        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 833        """
 834        activation_abs_batch_mean = torch.mean(
 835            torch.abs(activation),
 836            dim=[
 837                i for i in range(activation.dim()) if i != 1
 838            ],  # average the features over batch samples
 839        )
 840        importance_step_layer = activation_abs_batch_mean
 841        importance_step_layer = importance_step_layer.detach()
 842
 843        return importance_step_layer
 844
 845    def get_importance_step_layer_weight_abs_sum_x_activation_abs(
 846        self: str,
 847        layer_name: str,
 848        activation: Tensor,
 849        if_output_weight: bool,
 850    ) -> Tensor:
 851        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
 852
 853        **Args:**
 854        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 855        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 856        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 857
 858        **Returns:**
 859        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 860        """
 861        layer = self.backbone.get_layer_by_name(layer_name)
 862
 863        if not if_output_weight:
 864            weight_abs = torch.abs(layer.weight.data)
 865            weight_abs_sum = torch.sum(
 866                weight_abs,
 867                dim=[
 868                    i for i in range(weight_abs.dim()) if i != 0
 869                ],  # sum over the input dimension
 870            )
 871        else:
 872            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
 873            weight_abs_sum = torch.sum(
 874                weight_abs,
 875                dim=[
 876                    i for i in range(weight_abs.dim()) if i != 1
 877                ],  # sum over the output dimension
 878            )
 879
 880        activation_abs_batch_mean = torch.mean(
 881            torch.abs(activation),
 882            dim=[
 883                i for i in range(activation.dim()) if i != 1
 884            ],  # average the features over batch samples
 885        )
 886
 887        importance_step_layer = weight_abs_sum * activation_abs_batch_mean
 888        importance_step_layer = importance_step_layer.detach()
 889
 890        return importance_step_layer
 891
 892    def get_importance_step_layer_gradient_x_activation_abs(
 893        self: str,
 894        layer_name: str,
 895        input: Tensor | tuple[Tensor, ...],
 896        target: Tensor | None,
 897        batch_idx: int,
 898        num_batches: int,
 899    ) -> Tensor:
 900        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum.
 901
 902        **Args:**
 903        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 904        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
 905        - **target** (`Tensor` | `None`): the target batch of the training step.
 906        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
 907        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
 908
 909        **Returns:**
 910        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 911        """
 912        layer = self.backbone.get_layer_by_name(layer_name)
 913
 914        input = input.requires_grad_()
 915
 916        # initialize the Layer Gradient X Activation object
 917        layer_gradient_x_activation = LayerGradientXActivation(
 918            forward_func=self.forward, layer=layer
 919        )
 920
 921        self.set_forward_func_return_logits_only(True)
 922        # calculate layer attribution of the step
 923        attribution = layer_gradient_x_activation.attribute(
 924            inputs=input,
 925            target=target,
 926            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
 927        )
 928        self.set_forward_func_return_logits_only(False)
 929
 930        attribution_abs_batch_mean = torch.mean(
 931            torch.abs(attribution),
 932            dim=[
 933                i for i in range(attribution.dim()) if i != 1
 934            ],  # average the features over batch samples
 935        )
 936
 937        importance_step_layer = attribution_abs_batch_mean
 938        importance_step_layer = importance_step_layer.detach()
 939
 940        return importance_step_layer
 941
 942    def get_importance_step_layer_weight_gradient_square_sum(
 943        self: str,
 944        layer_name: str,
 945        activation: Tensor,
 946        if_output_weight: bool,
 947    ) -> Tensor:
 948        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 949
 950        **Args:**
 951        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 952        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 953        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 954
 955        **Returns:**
 956        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 957        """
 958        layer = self.backbone.get_layer_by_name(layer_name)
 959
 960        if not if_output_weight:
 961            gradient_square = layer.weight.grad.data**2
 962            gradient_square_sum = torch.sum(
 963                gradient_square,
 964                dim=[
 965                    i for i in range(gradient_square.dim()) if i != 0
 966                ],  # sum over the input dimension
 967            )
 968        else:
 969            gradient_square = self.next_layer(layer_name).weight.grad.data**2
 970            gradient_square_sum = torch.sum(
 971                gradient_square,
 972                dim=[
 973                    i for i in range(gradient_square.dim()) if i != 1
 974                ],  # sum over the output dimension
 975            )
 976
 977        importance_step_layer = gradient_square_sum
 978        importance_step_layer = importance_step_layer.detach()
 979
 980        return importance_step_layer
 981
 982    def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 983        self: str,
 984        layer_name: str,
 985        activation: Tensor,
 986        if_output_weight: bool,
 987    ) -> Tensor:
 988        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 989
 990        **Args:**
 991        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 992        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 993        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 994
 995        **Returns:**
 996        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 997        """
 998        layer = self.backbone.get_layer_by_name(layer_name)
 999
1000        if not if_output_weight:
1001            gradient_square = layer.weight.grad.data**2
1002            gradient_square_sum = torch.sum(
1003                gradient_square,
1004                dim=[
1005                    i for i in range(gradient_square.dim()) if i != 0
1006                ],  # sum over the input dimension
1007            )
1008        else:
1009            gradient_square = self.next_layer(layer_name).weight.grad.data**2
1010            gradient_square_sum = torch.sum(
1011                gradient_square,
1012                dim=[
1013                    i for i in range(gradient_square.dim()) if i != 1
1014                ],  # sum over the output dimension
1015            )
1016
1017        activation_abs_batch_mean = torch.mean(
1018            torch.abs(activation),
1019            dim=[
1020                i for i in range(activation.dim()) if i != 1
1021            ],  # average the features over batch samples
1022        )
1023
1024        importance_step_layer = gradient_square_sum * activation_abs_batch_mean
1025        importance_step_layer = importance_step_layer.detach()
1026
1027        return importance_step_layer
1028
1029    def get_importance_step_layer_conductance_abs(
1030        self: str,
1031        layer_name: str,
1032        input: Tensor | tuple[Tensor, ...],
1033        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1034        target: Tensor | None,
1035        batch_idx: int,
1036        num_batches: int,
1037    ) -> Tensor:
1038        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.
1039
1040        **Args:**
1041        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1042        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1043        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
1044        - **target** (`Tensor` | `None`): the target batch of the training step.
1045        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1046        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.- **mask** (`Tensor`): the mask tensor of the layer. It has the same size as the feature tensor with size (number of units, ).
1047
1048        **Returns:**
1049        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1050        """
1051        layer = self.backbone.get_layer_by_name(layer_name)
1052
1053        # initialize the Layer Conductance object
1054        layer_conductance = LayerConductance(forward_func=self.forward, layer=layer)
1055
1056        self.set_forward_func_return_logits_only(True)
1057        # calculate layer attribution of the step
1058        attribution = layer_conductance.attribute(
1059            inputs=input,
1060            baselines=baselines,
1061            target=target,
1062            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1063        )
1064        self.set_forward_func_return_logits_only(False)
1065
1066        attribution_abs_batch_mean = torch.mean(
1067            torch.abs(attribution),
1068            dim=[
1069                i for i in range(attribution.dim()) if i != 1
1070            ],  # average the features over batch samples
1071        )
1072
1073        importance_step_layer = attribution_abs_batch_mean
1074        importance_step_layer = importance_step_layer.detach()
1075
1076        return importance_step_layer
1077
1078    def get_importance_step_layer_internal_influence_abs(
1079        self: str,
1080        layer_name: str,
1081        input: Tensor | tuple[Tensor, ...],
1082        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1083        target: Tensor | None,
1084        batch_idx: int,
1085        num_batches: int,
1086    ) -> Tensor:
1087        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum.
1088
1089        **Args:**
1090        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1091        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1092        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details.
1093        - **target** (`Tensor` | `None`): the target batch of the training step.
1094        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1095        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1096
1097        **Returns:**
1098        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1099        """
1100        layer = self.backbone.get_layer_by_name(layer_name)
1101
1102        # initialize the Internal Influence object
1103        internal_influence = InternalInfluence(forward_func=self.forward, layer=layer)
1104
1105        # convert the target to long type to avoid error
1106        target = target.long() if target is not None else None
1107
1108        self.set_forward_func_return_logits_only(True)
1109        # calculate layer attribution of the step
1110        attribution = internal_influence.attribute(
1111            inputs=input,
1112            baselines=baselines,
1113            target=target,
1114            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1115            n_steps=5,  # set 10 instead of default 50 to accelerate the computation
1116        )
1117        self.set_forward_func_return_logits_only(False)
1118
1119        attribution_abs_batch_mean = torch.mean(
1120            torch.abs(attribution),
1121            dim=[
1122                i for i in range(attribution.dim()) if i != 1
1123            ],  # average the features over batch samples
1124        )
1125
1126        importance_step_layer = attribution_abs_batch_mean
1127        importance_step_layer = importance_step_layer.detach()
1128
1129        return importance_step_layer
1130
1131    def get_importance_step_layer_gradcam_abs(
1132        self: str,
1133        layer_name: str,
1134        input: Tensor | tuple[Tensor, ...],
1135        target: Tensor | None,
1136        batch_idx: int,
1137        num_batches: int,
1138    ) -> Tensor:
1139        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum.
1140
1141        **Args:**
1142        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1143        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1144        - **target** (`Tensor` | `None`): the target batch of the training step.
1145        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1146        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1147
1148        **Returns:**
1149        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1150        """
1151        layer = self.backbone.get_layer_by_name(layer_name)
1152
1153        # initialize the GradCAM object
1154        gradcam = LayerGradCam(forward_func=self.forward, layer=layer)
1155
1156        self.set_forward_func_return_logits_only(True)
1157        # calculate layer attribution of the step
1158        attribution = gradcam.attribute(
1159            inputs=input,
1160            target=target,
1161            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1162        )
1163        self.set_forward_func_return_logits_only(False)
1164
1165        attribution_abs_batch_mean = torch.mean(
1166            torch.abs(attribution),
1167            dim=[
1168                i for i in range(attribution.dim()) if i != 1
1169            ],  # average the features over batch samples
1170        )
1171
1172        importance_step_layer = attribution_abs_batch_mean
1173        importance_step_layer = importance_step_layer.detach()
1174
1175        return importance_step_layer
1176
1177    def get_importance_step_layer_deeplift_abs(
1178        self: str,
1179        layer_name: str,
1180        input: Tensor | tuple[Tensor, ...],
1181        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1182        target: Tensor | None,
1183        batch_idx: int,
1184        num_batches: int,
1185    ) -> Tensor:
1186        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum.
1187
1188        **Args:**
1189        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1190        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1191        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details.
1192        - **target** (`Tensor` | `None`): the target batch of the training step.
1193        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1194        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1195
1196        **Returns:**
1197        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1198        """
1199        layer = self.backbone.get_layer_by_name(layer_name)
1200
1201        # initialize the Layer DeepLift object
1202        layer_deeplift = LayerDeepLift(model=self, layer=layer)
1203
1204        # convert the target to long type to avoid error
1205        target = target.long() if target is not None else None
1206
1207        self.set_forward_func_return_logits_only(True)
1208        # calculate layer attribution of the step
1209        attribution = layer_deeplift.attribute(
1210            inputs=input,
1211            baselines=baselines,
1212            target=target,
1213            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1214        )
1215        self.set_forward_func_return_logits_only(False)
1216
1217        attribution_abs_batch_mean = torch.mean(
1218            torch.abs(attribution),
1219            dim=[
1220                i for i in range(attribution.dim()) if i != 1
1221            ],  # average the features over batch samples
1222        )
1223
1224        importance_step_layer = attribution_abs_batch_mean
1225        importance_step_layer = importance_step_layer.detach()
1226
1227        return importance_step_layer
1228
1229    def get_importance_step_layer_deepliftshap_abs(
1230        self: str,
1231        layer_name: str,
1232        input: Tensor | tuple[Tensor, ...],
1233        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1234        target: Tensor | None,
1235        batch_idx: int,
1236        num_batches: int,
1237    ) -> Tensor:
1238        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.
1239
1240        **Args:**
1241        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1242        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1243        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
1244        - **target** (`Tensor` | `None`): the target batch of the training step.
1245        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1246        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1247
1248        **Returns:**
1249        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1250        """
1251        layer = self.backbone.get_layer_by_name(layer_name)
1252
1253        # initialize the Layer DeepLiftShap object
1254        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)
1255
1256        # convert the target to long type to avoid error
1257        target = target.long() if target is not None else None
1258
1259        self.set_forward_func_return_logits_only(True)
1260        # calculate layer attribution of the step
1261        attribution = layer_deepliftshap.attribute(
1262            inputs=input,
1263            baselines=baselines,
1264            target=target,
1265            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1266        )
1267        self.set_forward_func_return_logits_only(False)
1268
1269        attribution_abs_batch_mean = torch.mean(
1270            torch.abs(attribution),
1271            dim=[
1272                i for i in range(attribution.dim()) if i != 1
1273            ],  # average the features over batch samples
1274        )
1275
1276        importance_step_layer = attribution_abs_batch_mean
1277        importance_step_layer = importance_step_layer.detach()
1278
1279        return importance_step_layer
1280
1281    def get_importance_step_layer_gradientshap_abs(
1282        self: str,
1283        layer_name: str,
1284        input: Tensor | tuple[Tensor, ...],
1285        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1286        target: Tensor | None,
1287        batch_idx: int,
1288        num_batches: int,
1289    ) -> Tensor:
1290        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.
1291
1292        **Args:**
1293        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1294        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1295        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
1296        - **target** (`Tensor` | `None`): the target batch of the training step.
1297        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1298        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1299
1300        **Returns:**
1301        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1302        """
1303        layer = self.backbone.get_layer_by_name(layer_name)
1304
1305        if baselines is None:
1306            baselines = torch.zeros_like(
1307                input
1308            )  # baselines are mandatory for GradientShap API. We explicitly set them to zero
1309
1310        # initialize the Layer GradientShap object
1311        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)
1312
1313        # convert the target to long type to avoid error
1314        target = target.long() if target is not None else None
1315
1316        self.set_forward_func_return_logits_only(True)
1317        # calculate layer attribution of the step
1318        attribution = layer_gradientshap.attribute(
1319            inputs=input,
1320            baselines=baselines,
1321            target=target,
1322            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1323        )
1324        self.set_forward_func_return_logits_only(False)
1325
1326        attribution_abs_batch_mean = torch.mean(
1327            torch.abs(attribution),
1328            dim=[
1329                i for i in range(attribution.dim()) if i != 1
1330            ],  # average the features over batch samples
1331        )
1332
1333        importance_step_layer = attribution_abs_batch_mean
1334        importance_step_layer = importance_step_layer.detach()
1335
1336        return importance_step_layer
1337
1338    def get_importance_step_layer_integrated_gradients_abs(
1339        self: str,
1340        layer_name: str,
1341        input: Tensor | tuple[Tensor, ...],
1342        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1343        target: Tensor | None,
1344        batch_idx: int,
1345        num_batches: int,
1346    ) -> Tensor:
1347        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.
1348
1349        **Args:**
1350        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1351        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1352        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
1353        - **target** (`Tensor` | `None`): the target batch of the training step.
1354        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1355        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1356
1357        **Returns:**
1358        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1359        """
1360        layer = self.backbone.get_layer_by_name(layer_name)
1361
1362        # initialize the Layer Integrated Gradients object
1363        layer_integrated_gradients = LayerIntegratedGradients(
1364            forward_func=self.forward, layer=layer
1365        )
1366
1367        self.set_forward_func_return_logits_only(True)
1368        # calculate layer attribution of the step
1369        attribution = layer_integrated_gradients.attribute(
1370            inputs=input,
1371            baselines=baselines,
1372            target=target,
1373            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1374        )
1375        self.set_forward_func_return_logits_only(False)
1376
1377        attribution_abs_batch_mean = torch.mean(
1378            torch.abs(attribution),
1379            dim=[
1380                i for i in range(attribution.dim()) if i != 1
1381            ],  # average the features over batch samples
1382        )
1383
1384        importance_step_layer = attribution_abs_batch_mean
1385        importance_step_layer = importance_step_layer.detach()
1386
1387        return importance_step_layer
1388
1389    def get_importance_step_layer_feature_ablation_abs(
1390        self: str,
1391        layer_name: str,
1392        input: Tensor | tuple[Tensor, ...],
1393        layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1394        target: Tensor | None,
1395        batch_idx: int,
1396        num_batches: int,
1397        if_captum: bool = False,
1398    ) -> Tensor:
1399        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum.
1400
1401        **Args:**
1402        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1403        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1404        - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details.
1405        - **target** (`Tensor` | `None`): the target batch of the training step.
1406        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1407        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1408        - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster.
1409
1410        **Returns:**
1411        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1412        """
1413        layer = self.backbone.get_layer_by_name(layer_name)
1414
1415        if not if_captum:
1416            # 1. Baseline logits (take first element of forward output)
1417            baseline_out, _, _ = self.forward(
1418                input, "train", batch_idx, num_batches, self.task_id
1419            )
1420            if target is not None:
1421                baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1)
1422            else:
1423                baseline_scores = baseline_out.sum(dim=1)
1424
1425            # 2. Capture layer’s output shape
1426            activs = {}
1427            handle = layer.register_forward_hook(
1428                lambda module, inp, out: activs.setdefault("output", out.detach())
1429            )
1430            _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id)
1431            handle.remove()
1432            layer_output = activs["output"]  # shape (B, F, ...)
1433
1434            # 3. Build baseline tensor matching that shape
1435            if layer_baselines is None:
1436                baseline_tensor = torch.zeros_like(layer_output)
1437            elif isinstance(layer_baselines, (int, float)):
1438                baseline_tensor = torch.full_like(layer_output, layer_baselines)
1439            elif isinstance(layer_baselines, Tensor):
1440                if layer_baselines.shape == layer_output.shape:
1441                    baseline_tensor = layer_baselines
1442                elif layer_baselines.shape == layer_output.shape[1:]:
1443                    baseline_tensor = layer_baselines.unsqueeze(0).repeat(
1444                        layer_output.size(0), *([1] * layer_baselines.ndim)
1445                    )
1446                else:
1447                    raise ValueError(...)
1448            else:
1449                raise ValueError(...)
1450
1451            B, F = layer_output.size(0), layer_output.size(1)
1452
1453            # 4. Create a “mega-batch” replicating the input F times
1454            if isinstance(input, tuple):
1455                mega_inputs = tuple(
1456                    t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:])
1457                    for t in input
1458                )
1459            else:
1460                mega_inputs = (
1461                    input.unsqueeze(0)
1462                    .repeat(F, *([1] * input.ndim))
1463                    .view(-1, *input.shape[1:])
1464                )
1465
1466            # 5. Equally replicate the baseline tensor
1467            mega_baseline = (
1468                baseline_tensor.unsqueeze(0)
1469                .repeat(F, *([1] * baseline_tensor.ndim))
1470                .view(-1, *baseline_tensor.shape[1:])
1471            )
1472
1473            # 6. Precompute vectorized indices
1474            device = layer_output.device
1475            positions = torch.arange(F * B, device=device)  # [0,1,...,F*B-1]
1476            feat_idx = torch.arange(F, device=device).repeat_interleave(
1477                B
1478            )  # [0,0,...,1,1,...,F-1]
1479
1480            # 7. One hook to zero out each channel slice across the mega-batch
1481            def mega_ablate_hook(module, inp, out):
1482                out_mod = out.clone()
1483                # for each sample in mega-batch, zero its corresponding channel
1484                out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx]
1485                return out_mod
1486
1487            h = layer.register_forward_hook(mega_ablate_hook)
1488            out_all, _, _ = self.forward(
1489                mega_inputs, "train", batch_idx, num_batches, self.task_id
1490            )
1491            h.remove()
1492
1493            # 8. Recover scores, reshape [F*B] → [F, B], diff & mean
1494            if target is not None:
1495                tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1)
1496                scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1)
1497            else:
1498                scores_all = out_all.sum(dim=1)
1499
1500            scores_all = scores_all.view(F, B)
1501            diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all)
1502            importance_step_layer = diffs.mean(dim=1).detach()  # [F]
1503
1504            return importance_step_layer
1505
1506        else:
1507            # initialize the Layer Feature Ablation object
1508            layer_feature_ablation = LayerFeatureAblation(
1509                forward_func=self.forward, layer=layer
1510            )
1511
1512            # calculate layer attribution of the step
1513            self.set_forward_func_return_logits_only(True)
1514            attribution = layer_feature_ablation.attribute(
1515                inputs=input,
1516                layer_baselines=layer_baselines,
1517                # target=target, # disable target to enable perturbations_per_eval
1518                additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1519                perturbations_per_eval=128,  # to accelerate the computation
1520            )
1521            self.set_forward_func_return_logits_only(False)
1522
1523            attribution_abs_batch_mean = torch.mean(
1524                torch.abs(attribution),
1525                dim=[
1526                    i for i in range(attribution.dim()) if i != 1
1527                ],  # average the features over batch samples
1528            )
1529
1530        importance_step_layer = attribution_abs_batch_mean
1531        importance_step_layer = importance_step_layer.detach()
1532
1533        return importance_step_layer
1534
1535    def get_importance_step_layer_lrp_abs(
1536        self: str,
1537        layer_name: str,
1538        input: Tensor | tuple[Tensor, ...],
1539        target: Tensor | None,
1540        batch_idx: int,
1541        num_batches: int,
1542    ) -> Tensor:
1543        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum.
1544
1545        **Args:**
1546        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1547        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1548        - **target** (`Tensor` | `None`): the target batch of the training step.
1549        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1550        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1551
1552        **Returns:**
1553        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1554        """
1555        layer = self.backbone.get_layer_by_name(layer_name)
1556
1557        # initialize the Layer LRP object
1558        layer_lrp = LayerLRP(model=self, layer=layer)
1559
1560        # set model to evaluation mode to prevent updating the model parameters
1561        self.eval()
1562
1563        self.set_forward_func_return_logits_only(True)
1564        # calculate layer attribution of the step
1565        attribution = layer_lrp.attribute(
1566            inputs=input,
1567            target=target,
1568            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1569        )
1570        self.set_forward_func_return_logits_only(False)
1571
1572        attribution_abs_batch_mean = torch.mean(
1573            torch.abs(attribution),
1574            dim=[
1575                i for i in range(attribution.dim()) if i != 1
1576            ],  # average the features over batch samples
1577        )
1578
1579        importance_step_layer = attribution_abs_batch_mean
1580        importance_step_layer = importance_step_layer.detach()
1581
1582        return importance_step_layer
1583
1584    def get_importance_step_layer_cbp_adaptive_contribution(
1585        self: str,
1586        layer_name: str,
1587        activation: Tensor,
1588    ) -> Tensor:
1589        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the reciprocal of sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
1590
1591        **Args:**
1592        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1593        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
1594
1595        **Returns:**
1596        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1597        """
1598        layer = self.backbone.get_layer_by_name(layer_name)
1599
1600        input_weight_abs = torch.abs(layer.weight.data)
1601        input_weight_abs_sum = torch.sum(
1602            input_weight_abs,
1603            dim=[
1604                i for i in range(input_weight_abs.dim()) if i != 0
1605            ],  # sum over the input dimension
1606        )
1607        input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum)
1608
1609        output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
1610        output_weight_abs_sum = torch.sum(
1611            output_weight_abs,
1612            dim=[
1613                i for i in range(output_weight_abs.dim()) if i != 1
1614            ],  # sum over the output dimension
1615        )
1616
1617        activation_abs_batch_mean = torch.mean(
1618            torch.abs(activation),
1619            dim=[
1620                i for i in range(activation.dim()) if i != 1
1621            ],  # average the features over batch samples
1622        )
1623
1624        importance_step_layer = (
1625            output_weight_abs_sum
1626            * activation_abs_batch_mean
1627            * input_weight_abs_sum_reciprocal
1628        )
1629        importance_step_layer = importance_step_layer.detach()
1630
1631        return importance_step_layer

FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm.

An architecture-based continual learning approach that improves AdaHAT (Adaptive Hard Attention to the Task) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.

We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components.

FGAdaHAT( backbone: clarena.backbones.HATMaskBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadDIL, adjustment_intensity: float, importance_type: str, importance_summing_strategy: str, importance_scheduler_type: str, neuron_to_weight_importance_aggregation_mode: str, s_max: float, clamp_threshold: float, mask_sparsity_reg_factor: float, mask_sparsity_reg_mode: str = 'original', base_importance: float = 0.01, base_mask_sparsity_reg: float = 0.1, base_linear: float = 10, filter_by_cumulative_mask: bool = False, filter_unmasked_importance: bool = True, step_multiply_training_mask: bool = True, task_embedding_init_mode: str = 'N01', importance_summing_strategy_linear_step: float | None = None, importance_summing_strategy_exponential_rate: float | None = None, importance_summing_strategy_log_base: float | None = None, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)
 45    def __init__(
 46        self,
 47        backbone: HATMaskBackbone,
 48        heads: HeadsTIL | HeadDIL,
 49        adjustment_intensity: float,
 50        importance_type: str,
 51        importance_summing_strategy: str,
 52        importance_scheduler_type: str,
 53        neuron_to_weight_importance_aggregation_mode: str,
 54        s_max: float,
 55        clamp_threshold: float,
 56        mask_sparsity_reg_factor: float,
 57        mask_sparsity_reg_mode: str = "original",
 58        base_importance: float = 0.01,
 59        base_mask_sparsity_reg: float = 0.1,
 60        base_linear: float = 10,
 61        filter_by_cumulative_mask: bool = False,
 62        filter_unmasked_importance: bool = True,
 63        step_multiply_training_mask: bool = True,
 64        task_embedding_init_mode: str = "N01",
 65        importance_summing_strategy_linear_step: float | None = None,
 66        importance_summing_strategy_exponential_rate: float | None = None,
 67        importance_summing_strategy_log_base: float | None = None,
 68        non_algorithmic_hparams: dict[str, Any] = {},
 69        **kwargs,
 70    ) -> None:
 71        r"""Initialize the FG-AdaHAT algorithm with the network.
 72
 73        **Args:**
 74        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
 75        - **heads** (`HeadsTIL` | `HeadDIL`): output heads. FG-AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
 76        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
 77        - **importance_type** (`str`): the type of neuron-wise importance, must be one of:
 78            1. 'input_weight_abs_sum': sum of absolute input weights;
 79            2. 'output_weight_abs_sum': sum of absolute output weights;
 80            3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
 81            4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
 82            5. 'activation_abs': absolute activation;
 83            6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
 84            7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
 85            8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
 86            9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
 87            10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
 88            11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
 89            12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
 90            13. 'conductance_abs': absolute layer conductance;
 91            14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
 92            15. 'gradcam_abs': absolute Grad-CAM;
 93            16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
 94            17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
 95            18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
 96            19. 'integrated_gradients_abs': absolute Integrated Gradients;
 97            20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
 98            21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
 99            22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
100            23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
101        - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of:
102            1. 'add_latest': add the latest neuron-wise importance to the summative importance;
103            2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
104            3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
105            4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
106            5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID;
107            6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID;
108            7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
109            8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
110            9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID;
111        - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of:
 112            1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`);
113            2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
114            3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
115        - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
116            1. 'min': take the minimum of neuron-wise importance for each weight;
117            2. 'max': take the maximum of neuron-wise importance for each weight;
118            3. 'mean': take the mean of neuron-wise importance for each weight.
119        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
120        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
121        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
122        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
123            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
124            2. 'cross': the cross version of mask sparsity regularization.
125        - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01.
126        - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
127        - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
128        - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False.
129        - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set to 0) at the end of task training. Default: False.
130        - **step_multiply_training_mask** (`bool`): whether to multiply the training mask to the importance at each training step. Default: True.
131        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
132            1. 'N01' (default): standard normal distribution $N(0, 1)$.
133            2. 'U-11': uniform distribution $U(-1, 1)$.
134            3. 'U01': uniform distribution $U(0, 1)$.
135            4. 'U-10': uniform distribution $U(-1, 0)$.
136            5. 'last': inherit the task embedding from the last task.
137        - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0.
138        - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1.
139        - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1.
140        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
141        - **kwargs**: Reserved for multiple inheritance.
142
143        """
144        super().__init__(
145            backbone=backbone,
146            heads=heads,
147            adjustment_mode=None,  # use the own adjustment mechanism of FG-AdaHAT
148            adjustment_intensity=adjustment_intensity,
149            s_max=s_max,
150            clamp_threshold=clamp_threshold,
151            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
152            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
153            task_embedding_init_mode=task_embedding_init_mode,
154            epsilon=base_mask_sparsity_reg,  # the epsilon is now the base mask sparsity regularization factor
155            non_algorithmic_hparams=non_algorithmic_hparams,
156            **kwargs,
157        )
158
159        # save additional algorithmic hyperparameters
160        self.save_hyperparameters(
161            "adjustment_intensity",
162            "importance_type",
163            "importance_summing_strategy",
164            "importance_scheduler_type",
165            "neuron_to_weight_importance_aggregation_mode",
166            "s_max",
167            "clamp_threshold",
168            "mask_sparsity_reg_factor",
169            "mask_sparsity_reg_mode",
170            "base_importance",
171            "base_mask_sparsity_reg",
172            "base_linear",
173            "filter_by_cumulative_mask",
174            "filter_unmasked_importance",
175            "step_multiply_training_mask",
176        )
177
178        self.importance_type: str | None = importance_type
179        r"""The type of the neuron-wise importance added to AdaHAT importance."""
180
181        self.importance_scheduler_type: str = importance_scheduler_type
182        r"""The type of the importance scheduler."""
183        self.neuron_to_weight_importance_aggregation_mode: str = (
184            neuron_to_weight_importance_aggregation_mode
185        )
186        r"""The mode of aggregation from neuron-wise to weight-wise importance. """
187        self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask
188        r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate."""
189        self.filter_unmasked_importance: bool = filter_unmasked_importance
190        r"""The flag to filter unmasked importance values (set them to 0) at the end of task training."""
191        self.step_multiply_training_mask: bool = step_multiply_training_mask
192        r"""The flag to multiply the training mask to the importance at each training step."""
193
194        # importance summing strategy
195        self.importance_summing_strategy: str = importance_summing_strategy
196        r"""The strategy to sum the neuron-wise importance for previous tasks."""
197        if importance_summing_strategy_linear_step is not None:
198            self.importance_summing_strategy_linear_step: float = (
199                importance_summing_strategy_linear_step
200            )
201            r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease')."""
202        if importance_summing_strategy_exponential_rate is not None:
203            self.importance_summing_strategy_exponential_rate: float = (
204                importance_summing_strategy_exponential_rate
205            )
206            r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """
207        if importance_summing_strategy_log_base is not None:
208            self.importance_summing_strategy_log_base: float = (
209                importance_summing_strategy_log_base
210            )
211            r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """
212
213        # base values
214        self.base_importance: float = base_importance
215        r"""The base value added to the importance to avoid zero. """
216        self.base_mask_sparsity_reg: float = base_mask_sparsity_reg
217        r"""The base value added to the mask sparsity regularization to avoid zero. """
218        self.base_linear: float = base_linear
219        r"""The base value added to the linear layer to avoid zero. """
220
221        self.importances: dict[int, dict[str, Tensor]] = {}
222        r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """
223        self.summative_importance_for_previous_tasks: dict[str, Tensor] = {}
224        r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """
225
226        self.num_steps_t: int
227        r"""The number of training steps for the current task `self.task_id`."""
228        # set manual optimization
229        self.automatic_optimization = False
230
231        FGAdaHAT.sanity_check(self)

Initialize the FG-AdaHAT algorithm with the network.

Args:

  • backbone (HATMaskBackbone): must be a backbone network with the HAT mask mechanism.
  • heads (HeadsTIL | HeadDIL): output heads. FG-AdaHAT supports TIL (Task-Incremental Learning) and DIL (Domain-Incremental Learning).
  • adjustment_intensity (float): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
  • importance_type (str): the type of neuron-wise importance, must be one of:
    1. 'input_weight_abs_sum': sum of absolute input weights;
    2. 'output_weight_abs_sum': sum of absolute output weights;
    3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
    4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
    5. 'activation_abs': absolute activation;
    6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
    7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
    8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
    9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
    10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
    11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
    12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
    13. 'conductance_abs': absolute layer conductance;
    14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
    15. 'gradcam_abs': absolute Grad-CAM;
    16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
    17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
    18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
    19. 'integrated_gradients_abs': absolute Integrated Gradients;
    20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
    21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
    22. 'cbp_adaptation': the adaptation function in Continual Backpropagation (CBP);
    23. 'cbp_adaptive_contribution': the adaptive contribution function in Continual Backpropagation (CBP);
  • importance_summing_strategy (str): the strategy to sum neuron-wise importance for previous tasks, must be one of:
    1. 'add_latest': add the latest neuron-wise importance to the summative importance;
    2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
    3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
    4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
    5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID;
    6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID;
    7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
    8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
    9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID;
  • importance_scheduler_type (str): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of:
    1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument base_linear), and $b_R$ is the base mask sparsity regularization factor (see argument base_mask_sparsity_reg);
    2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
    3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
  • neuron_to_weight_importance_aggregation_mode (str): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
    1. 'min': take the minimum of neuron-wise importance for each weight;
    2. 'max': take the maximum of neuron-wise importance for each weight;
    3. 'mean': take the mean of neuron-wise importance for each weight.
  • s_max (float): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
  • clamp_threshold (float): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
  • mask_sparsity_reg_factor (float): hyperparameter, the regularization factor for mask sparsity.
  • mask_sparsity_reg_mode (str): the mode of mask sparsity regularization, must be one of:
    1. 'original' (default): the original mask sparsity regularization in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularization.
  • base_importance (float): base value added to importance ($b_I$ in the paper). Default: 0.01.
  • base_mask_sparsity_reg (float): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
  • base_linear (float): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
  • filter_by_cumulative_mask (bool): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False.
  • filter_unmasked_importance (bool): whether to filter unmasked importance values (set to 0) at the end of task training. Default: False.
  • step_multiply_training_mask (bool): whether to multiply the training mask to the importance at each training step. Default: True.
  • task_embedding_init_mode (str): the initialization mode for task embeddings, must be one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
  • importance_summing_strategy_linear_step (float | None): linear step for the importance summing strategy (used when importance_summing_strategy is 'linear_decrease'). Must be > 0.
  • importance_summing_strategy_exponential_rate (float | None): exponential rate for the importance summing strategy (used when importance_summing_strategy is 'exponential_decrease'). Must be > 1.
  • importance_summing_strategy_log_base (float | None): base for the logarithm in the importance summing strategy (used when importance_summing_strategy is 'log_decrease'). Must be > 1.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from save_hyperparameters() method. This is useful for the experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
importance_type: str | None

The type of the neuron-wise importance added to AdaHAT importance.

importance_scheduler_type: str

The type of the importance scheduler.

neuron_to_weight_importance_aggregation_mode: str

The mode of aggregation from neuron-wise to weight-wise importance.

filter_by_cumulative_mask: bool

The flag to filter importance by the cumulative mask when calculating the adjustment rate.

filter_unmasked_importance: bool

The flag to filter unmasked importance values (set them to 0) at the end of task training.

step_multiply_training_mask: bool

The flag to multiply the training mask to the importance at each training step.

importance_summing_strategy: str

The strategy to sum the neuron-wise importance for previous tasks.

base_importance: float

The base value added to the importance to avoid zero.

base_mask_sparsity_reg: float

The base value added to the mask sparsity regularization to avoid zero.

base_linear: float

The base value added to the linear layer to avoid zero.

importances: dict[int, dict[str, torch.Tensor]]

The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. Each layer's importance tensor is the same size as the feature tensor with size (number of units, ).

summative_importance_for_previous_tasks: dict[str, torch.Tensor]

The summative neuron-wise importance values of units for previous tasks before the current task self.task_id. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ).

num_steps_t: int

The number of training steps for the current task self.task_id.

automatic_optimization: bool
290    @property
291    def automatic_optimization(self) -> bool:
292        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
293        return self._automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

def sanity_check(self) -> None:
233    def sanity_check(self) -> None:
234        r"""Sanity check."""
235
236        # check importance type
237        if self.importance_type not in [
238            "input_weight_abs_sum",
239            "output_weight_abs_sum",
240            "input_weight_gradient_abs_sum",
241            "output_weight_gradient_abs_sum",
242            "activation_abs",
243            "input_weight_abs_sum_x_activation_abs",
244            "output_weight_abs_sum_x_activation_abs",
245            "gradient_x_activation_abs",
246            "input_weight_gradient_square_sum",
247            "output_weight_gradient_square_sum",
248            "input_weight_gradient_square_sum_x_activation_abs",
249            "output_weight_gradient_square_sum_x_activation_abs",
250            "conductance_abs",
251            "internal_influence_abs",
252            "gradcam_abs",
253            "deeplift_abs",
254            "deepliftshap_abs",
255            "gradientshap_abs",
256            "integrated_gradients_abs",
257            "feature_ablation_abs",
258            "lrp_abs",
259            "cbp_adaptation",
260            "cbp_adaptive_contribution",
261        ]:
262            raise ValueError(
263                f"importance_type must be one of the predefined types, but got {self.importance_type}"
264            )
265
266        # check importance summing strategy
267        if self.importance_summing_strategy not in [
268            "add_latest",
269            "add_all",
270            "add_average",
271            "linear_decrease",
272            "quadratic_decrease",
273            "cubic_decrease",
274            "exponential_decrease",
275            "log_decrease",
276            "factorial_decrease",
277        ]:
278            raise ValueError(
279                f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}"
280            )
281
282        # check importance scheduler type
283        if self.importance_scheduler_type not in [
284            "linear_sparsity_reg",
285            "sparsity_reg",
286            "summative_mask_sparsity_reg",
287        ]:
288            raise ValueError(
289                f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}"
290            )
291
292        # check neuron to weight importance aggregation mode
293        if self.neuron_to_weight_importance_aggregation_mode not in [
294            "min",
295            "max",
296            "mean",
297        ]:
298            raise ValueError(
299                f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}"
300            )
301
302        # check base values
303        if self.base_importance < 0:
304            raise ValueError(
305                f"base_importance must be >= 0, but got {self.base_importance}"
306            )
307        if self.base_mask_sparsity_reg <= 0:
308            raise ValueError(
309                f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}"
310            )
311        if self.base_linear <= 0:
312            raise ValueError(f"base_linear must be > 0, but got {self.base_linear}")

Sanity check.

def on_train_start(self) -> None:
314    def on_train_start(self) -> None:
315        r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization."""
316        super().on_train_start()
317
318        self.importances[self.task_id] = (
319            {}
320        )  # initialize the importance for the current task
321
322        # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time.
323        for layer_name in self.backbone.weighted_layer_names:
324            layer = self.backbone.get_layer_by_name(
325                layer_name
326            )  # get the layer by its name
327            num_units = layer.weight.shape[0]
328
329            # initialize the accumulated importance at the beginning of each task
330            self.importances[self.task_id][layer_name] = torch.zeros(num_units).to(
331                self.device
332            )
333
334            # reset the number of steps counter for the current task
335            self.num_steps_t = 0
336
337            # initialize the summative neuron-wise importance at the beginning of the first task
338            if self.task_id == 1:
339                self.summative_importance_for_previous_tasks[layer_name] = torch.zeros(
340                    num_units
341                ).to(
342                    self.device
343                )  # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$

Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.

def clip_grad_by_adjustment( self, network_sparsity: dict[str, torch.Tensor]) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], torch.Tensor]:
345    def clip_grad_by_adjustment(
346        self,
347        network_sparsity: dict[str, Tensor],
348    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
349        r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper.
350
351        Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.
352
353        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
354
355        **Args:**
356        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.
357
358        **Returns:**
359        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
360        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
361        - **capacity** (`Tensor`): the calculated network capacity.
362        """
363
364        # initialize network capacity metric
365        capacity = HATNetworkCapacityMetric().to(self.device)
366        adjustment_rate_weight = {}
367        adjustment_rate_bias = {}
368
369        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper
370        for layer_name in self.backbone.weighted_layer_names:
371
372            layer = self.backbone.get_layer_by_name(
373                layer_name
374            )  # get the layer by its name
375
376            # placeholder for the adjustment rate to avoid the error of using it before assignment
377            adjustment_rate_weight_layer = 1
378            adjustment_rate_bias_layer = 1
379
380            # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, added the base value, and filtered by the mask
381            weight_importance, bias_importance = (
382                self.backbone.get_layer_measure_parameter_wise(
383                    neuron_wise_measure=self.summative_importance_for_previous_tasks,
384                    layer_name=layer_name,
385                    aggregation_mode=self.neuron_to_weight_importance_aggregation_mode,
386                )
387            )
388
389            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
390                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
391                layer_name=layer_name,
392                aggregation_mode="min",
393            )
394
395            # filter the weight importance by the cumulative mask
396            if self.filter_by_cumulative_mask:
397                weight_importance = weight_importance * weight_mask
398                bias_importance = bias_importance * bias_mask
399
400            network_sparsity_layer = network_sparsity[layer_name]
401
402            # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper
403            factor = network_sparsity_layer + self.base_mask_sparsity_reg
404            if self.importance_scheduler_type == "linear_sparsity_reg":
405                factor = factor * (self.task_id + self.base_linear)
406            elif self.importance_scheduler_type == "sparsity_reg":
407                pass
408            elif self.importance_scheduler_type == "summative_mask_sparsity_reg":
409                factor = factor * (
410                    self.summative_mask_for_previous_tasks + self.base_linear
411                )
412
413            # calculate the adjustment rate
414            adjustment_rate_weight_layer = torch.div(
415                self.adjustment_intensity,
416                (factor * weight_importance + self.adjustment_intensity),
417            )
418
419            adjustment_rate_bias_layer = torch.div(
420                self.adjustment_intensity,
421                (factor * bias_importance + self.adjustment_intensity),
422            )
423
424            # apply the adjustment rate to the gradients
425            layer.weight.grad.data *= adjustment_rate_weight_layer
426            if layer.bias is not None:
427                layer.bias.grad.data *= adjustment_rate_bias_layer
428
429            # store the adjustment rate for logging
430            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
431            if layer.bias is not None:
432                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
433
434            # update network capacity metric
435            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
436
437        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()

Clip the gradients by the adjustment rate. See Eq. (1) in the paper.

Note that because the task embedding fully covers every layer in the backbone network, no parameters are left out of this system. This applies not only to parameters between layers with task embeddings, but also to those before the first layer. We design it separately in the code.

Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.

Args:

  • network_sparsity (dict[str, Tensor]): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.

Returns:

  • adjustment_rate_weight (dict[str, Tensor]): the adjustment rate for weights. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
  • adjustment_rate_bias (dict[str, Tensor]): the adjustment rate for biases. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
  • capacity (Tensor): the calculated network capacity.
def on_train_batch_end(self, outputs: dict[str, Any], batch: Any, batch_idx: int) -> None:
439    def on_train_batch_end(
440        self, outputs: dict[str, Any], batch: Any, batch_idx: int
441    ) -> None:
442        r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.
443
444        **Args:**
445        - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`).
446        - **batch** (`Any`): training data batch.
447        - **batch_idx** (`int`): index of the current batch (for mask figure file name).
448        """
449
450        # get potential useful information from training batch
451        activations = outputs["activations"]
452        input = outputs["input"]
453        target = outputs["target"]
454        mask = outputs["mask"]
455        num_batches = self.trainer.num_training_batches
456
457        for layer_name in self.backbone.weighted_layer_names:
458            # layer-wise operation
459
460            activation = activations[layer_name]
461
462            # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper.
463            if self.importance_type == "input_weight_abs_sum":
464                importance_step = self.get_importance_step_layer_weight_abs_sum(
465                    layer_name=layer_name,
466                    if_output_weight=False,
467                    reciprocal=False,
468                )
469            elif self.importance_type == "output_weight_abs_sum":
470                importance_step = self.get_importance_step_layer_weight_abs_sum(
471                    layer_name=layer_name,
472                    if_output_weight=True,
473                    reciprocal=False,
474                )
475            elif self.importance_type == "input_weight_gradient_abs_sum":
476                importance_step = (
477                    self.get_importance_step_layer_weight_gradient_abs_sum(
478                        layer_name=layer_name, if_output_weight=False
479                    )
480                )
481            elif self.importance_type == "output_weight_gradient_abs_sum":
482                importance_step = (
483                    self.get_importance_step_layer_weight_gradient_abs_sum(
484                        layer_name=layer_name, if_output_weight=True
485                    )
486                )
487            elif self.importance_type == "activation_abs":
488                importance_step = self.get_importance_step_layer_activation_abs(
489                    activation=activation
490                )
491            elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
492                importance_step = (
493                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
494                        layer_name=layer_name,
495                        activation=activation,
496                        if_output_weight=False,
497                    )
498                )
499            elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
500                importance_step = (
501                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
502                        layer_name=layer_name,
503                        activation=activation,
504                        if_output_weight=True,
505                    )
506                )
507            elif self.importance_type == "gradient_x_activation_abs":
508                importance_step = (
509                    self.get_importance_step_layer_gradient_x_activation_abs(
510                        layer_name=layer_name,
511                        input=input,
512                        target=target,
513                        batch_idx=batch_idx,
514                        num_batches=num_batches,
515                    )
516                )
517            elif self.importance_type == "input_weight_gradient_square_sum":
518                importance_step = (
519                    self.get_importance_step_layer_weight_gradient_square_sum(
520                        layer_name=layer_name,
521                        activation=activation,
522                        if_output_weight=False,
523                    )
524                )
525            elif self.importance_type == "output_weight_gradient_square_sum":
526                importance_step = (
527                    self.get_importance_step_layer_weight_gradient_square_sum(
528                        layer_name=layer_name,
529                        activation=activation,
530                        if_output_weight=True,
531                    )
532                )
533            elif (
534                self.importance_type
535                == "input_weight_gradient_square_sum_x_activation_abs"
536            ):
537                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
538                    layer_name=layer_name,
539                    activation=activation,
540                    if_output_weight=False,
541                )
542            elif (
543                self.importance_type
544                == "output_weight_gradient_square_sum_x_activation_abs"
545            ):
546                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
547                    layer_name=layer_name,
548                    activation=activation,
549                    if_output_weight=True,
550                )
551            elif self.importance_type == "conductance_abs":
552                importance_step = self.get_importance_step_layer_conductance_abs(
553                    layer_name=layer_name,
554                    input=input,
555                    baselines=None,
556                    target=target,
557                    batch_idx=batch_idx,
558                    num_batches=num_batches,
559                )
560            elif self.importance_type == "internal_influence_abs":
561                importance_step = self.get_importance_step_layer_internal_influence_abs(
562                    layer_name=layer_name,
563                    input=input,
564                    baselines=None,
565                    target=target,
566                    batch_idx=batch_idx,
567                    num_batches=num_batches,
568                )
569            elif self.importance_type == "gradcam_abs":
570                importance_step = self.get_importance_step_layer_gradcam_abs(
571                    layer_name=layer_name,
572                    input=input,
573                    target=target,
574                    batch_idx=batch_idx,
575                    num_batches=num_batches,
576                )
577            elif self.importance_type == "deeplift_abs":
578                importance_step = self.get_importance_step_layer_deeplift_abs(
579                    layer_name=layer_name,
580                    input=input,
581                    baselines=None,
582                    target=target,
583                    batch_idx=batch_idx,
584                    num_batches=num_batches,
585                )
586            elif self.importance_type == "deepliftshap_abs":
587                importance_step = self.get_importance_step_layer_deepliftshap_abs(
588                    layer_name=layer_name,
589                    input=input,
590                    baselines=None,
591                    target=target,
592                    batch_idx=batch_idx,
593                    num_batches=num_batches,
594                )
595            elif self.importance_type == "gradientshap_abs":
596                importance_step = self.get_importance_step_layer_gradientshap_abs(
597                    layer_name=layer_name,
598                    input=input,
599                    baselines=None,
600                    target=target,
601                    batch_idx=batch_idx,
602                    num_batches=num_batches,
603                )
604            elif self.importance_type == "integrated_gradients_abs":
605                importance_step = (
606                    self.get_importance_step_layer_integrated_gradients_abs(
607                        layer_name=layer_name,
608                        input=input,
609                        baselines=None,
610                        target=target,
611                        batch_idx=batch_idx,
612                        num_batches=num_batches,
613                    )
614                )
615            elif self.importance_type == "feature_ablation_abs":
616                importance_step = self.get_importance_step_layer_feature_ablation_abs(
617                    layer_name=layer_name,
618                    input=input,
619                    layer_baselines=None,
620                    target=target,
621                    batch_idx=batch_idx,
622                    num_batches=num_batches,
623                )
624            elif self.importance_type == "lrp_abs":
625                importance_step = self.get_importance_step_layer_lrp_abs(
626                    layer_name=layer_name,
627                    input=input,
628                    target=target,
629                    batch_idx=batch_idx,
630                    num_batches=num_batches,
631                )
632            elif self.importance_type == "cbp_adaptation":
633                importance_step = self.get_importance_step_layer_weight_abs_sum(
634                    layer_name=layer_name,
635                    if_output_weight=False,
636                    reciprocal=True,
637                )
638            elif self.importance_type == "cbp_adaptive_contribution":
639                importance_step = (
640                    self.get_importance_step_layer_cbp_adaptive_contribution(
641                        layer_name=layer_name,
642                        activation=activation,
643                    )
644                )
645
646            importance_step = min_max_normalize(
647                importance_step
648            )  # min-max scaling the utility to $[0, 1]$. See Eq. (5) in the paper
649
650            # multiply the importance by the training mask. See Eq. (6) in the paper
651            if self.step_multiply_training_mask:
652                importance_step = importance_step * mask[layer_name]
653
654            # update accumulated importance
655            self.importances[self.task_id][layer_name] = (
656                self.importances[self.task_id][layer_name] + importance_step
657            )
658
659        # update number of steps counter
660        self.num_steps_t += 1

Calculate the step-wise importance, update the accumulated importance, and update the number-of-steps counter after each training step.

Args:

  • outputs (dict[str, Any]): outputs of the training step (returns of training_step() in CLAlgorithm).
  • batch (Any): training data batch.
  • batch_idx (int): index of the current batch (for the mask figure file name).
def on_train_end(self) -> None:
662    def on_train_end(self) -> None:
663        r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task."""
664        super().on_train_end()  # store the mask and update cumulative and summative masks
665
666        for layer_name in self.backbone.weighted_layer_names:
667
668            # average the neuron-wise step importance. See Eq. (4) in the paper
669            self.importances[self.task_id][layer_name] = (
670                self.importances[self.task_id][layer_name]
671            ) / self.num_steps_t
672
673            # add the base importance. See Eq. (6) in the paper
674            self.importances[self.task_id][layer_name] = (
675                self.importances[self.task_id][layer_name] + self.base_importance
676            )
677
678            # filter unmasked importance
679            if self.filter_unmasked_importance:
680                self.importances[self.task_id][layer_name] = (
681                    self.importances[self.task_id][layer_name]
682                    * self.backbone.masks[f"{self.task_id}"][layer_name]
683                )
684
685            # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
686            if self.importance_summing_strategy == "add_latest":
687                self.summative_importance_for_previous_tasks[
688                    layer_name
689                ] += self.importances[self.task_id][layer_name]
690
691            elif self.importance_summing_strategy == "add_all":
692                for t in range(1, self.task_id + 1):
693                    self.summative_importance_for_previous_tasks[
694                        layer_name
695                    ] += self.importances[t][layer_name]
696
697            elif self.importance_summing_strategy == "add_average":
698                for t in range(1, self.task_id + 1):
699                    self.summative_importance_for_previous_tasks[layer_name] += (
700                        self.importances[t][layer_name] / self.task_id
701                    )
702            else:
703                self.summative_importance_for_previous_tasks[
704                    layer_name
705                ] = torch.zeros_like(
706                    self.summative_importance_for_previous_tasks[layer_name]
707                ).to(
708                    self.device
709                )  # starting adding from 0
710
711                if self.importance_summing_strategy == "linear_decrease":
712                    s = self.importance_summing_strategy_linear_step
713                    for t in range(1, self.task_id + 1):
714                        w_t = s * (self.task_id - t) + 1
715
716                elif self.importance_summing_strategy == "quadratic_decrease":
717                    for t in range(1, self.task_id + 1):
718                        w_t = (self.task_id - t + 1) ** 2
719                elif self.importance_summing_strategy == "cubic_decrease":
720                    for t in range(1, self.task_id + 1):
721                        w_t = (self.task_id - t + 1) ** 3
722                elif self.importance_summing_strategy == "exponential_decrease":
723                    for t in range(1, self.task_id + 1):
724                        r = self.importance_summing_strategy_exponential_rate
725
726                        w_t = r ** (self.task_id - t + 1)
727                elif self.importance_summing_strategy == "log_decrease":
728                    a = self.importance_summing_strategy_log_base
729                    for t in range(1, self.task_id + 1):
730                        w_t = math.log(self.task_id - t, a) + 1
731                elif self.importance_summing_strategy == "factorial_decrease":
732                    for t in range(1, self.task_id + 1):
733                        w_t = math.factorial(self.task_id - t + 1)
734                else:
735                    raise ValueError
736                self.summative_importance_for_previous_tasks[layer_name] += (
737                    self.importances[t][layer_name] * w_t
738                )

Additionally calculate neuron-wise importance for previous tasks at the end of training each task.

def get_importance_step_layer_weight_abs_sum(self, layer_name: str, if_output_weight: bool, reciprocal: bool) -> torch.Tensor:
740    def get_importance_step_layer_weight_abs_sum(
741        self: str,
742        layer_name: str,
743        if_output_weight: bool,
744        reciprocal: bool,
745    ) -> Tensor:
746        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.
747
748        **Args:**
749        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
750        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
751        - **reciprocal** (`bool`): whether to take reciprocal.
752
753        **Returns:**
754        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
755        """
756        layer = self.backbone.get_layer_by_name(layer_name)
757
758        if not if_output_weight:
759            weight_abs = torch.abs(layer.weight.data)
760            weight_abs_sum = torch.sum(
761                weight_abs,
762                dim=[
763                    i for i in range(weight_abs.dim()) if i != 0
764                ],  # sum over the input dimension
765            )
766        else:
767            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
768            weight_abs_sum = torch.sum(
769                weight_abs,
770                dim=[
771                    i for i in range(weight_abs.dim()) if i != 1
772                ],  # sum over the output dimension
773            )
774
775        if reciprocal:
776            weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum)
777            importance_step_layer = weight_abs_sum_reciprocal
778        else:
779            importance_step_layer = weight_abs_sum
780        importance_step_layer = importance_step_layer.detach()
781
782        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • if_output_weight (bool): whether to use the output weights or input weights.
  • reciprocal (bool): whether to take reciprocal.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_weight_gradient_abs_sum(self, layer_name: str, if_output_weight: bool) -> torch.Tensor:
784    def get_importance_step_layer_weight_gradient_abs_sum(
785        self: str,
786        layer_name: str,
787        if_output_weight: bool,
788    ) -> Tensor:
789        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.
790
791        **Args:**
792        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
793        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
794
795        **Returns:**
796        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
797        """
798        layer = self.backbone.get_layer_by_name(layer_name)
799
800        if not if_output_weight:
801            gradient_abs = torch.abs(layer.weight.grad.data)
802            gradient_abs_sum = torch.sum(
803                gradient_abs,
804                dim=[
805                    i for i in range(gradient_abs.dim()) if i != 0
806                ],  # sum over the input dimension
807            )
808        else:
809            gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data)
810            gradient_abs_sum = torch.sum(
811                gradient_abs,
812                dim=[
813                    i for i in range(gradient_abs.dim()) if i != 1
814                ],  # sum over the output dimension
815            )
816
817        importance_step_layer = gradient_abs_sum
818        importance_step_layer = importance_step_layer.detach()
819
820        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • if_output_weight (bool): whether to use the output weights or input weights.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_activation_abs(self, activation: torch.Tensor) -> torch.Tensor:
822    def get_importance_step_layer_activation_abs(
823        self: str,
824        activation: Tensor,
825    ) -> Tensor:
826        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum.
827
828        **Args:**
829        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
830
831        **Returns:**
832        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
833        """
834        activation_abs_batch_mean = torch.mean(
835            torch.abs(activation),
836            dim=[
837                i for i in range(activation.dim()) if i != 1
838            ],  # average the features over batch samples
839        )
840        importance_step_layer = activation_abs_batch_mean
841        importance_step_layer = importance_step_layer.detach()
842
843        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of Layer Activation in Captum.

Args:

  • activation (Tensor): the activation tensor of the layer. It has the same size of (number of units, ).

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_weight_abs_sum_x_activation_abs(self, layer_name: str, activation: torch.Tensor, if_output_weight: bool) -> torch.Tensor:
845    def get_importance_step_layer_weight_abs_sum_x_activation_abs(
846        self: str,
847        layer_name: str,
848        activation: Tensor,
849        if_output_weight: bool,
850    ) -> Tensor:
851        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
852
853        **Args:**
854        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
855        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
856        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
857
858        **Returns:**
859        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
860        """
861        layer = self.backbone.get_layer_by_name(layer_name)
862
863        if not if_output_weight:
864            weight_abs = torch.abs(layer.weight.data)
865            weight_abs_sum = torch.sum(
866                weight_abs,
867                dim=[
868                    i for i in range(weight_abs.dim()) if i != 0
869                ],  # sum over the input dimension
870            )
871        else:
872            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
873            weight_abs_sum = torch.sum(
874                weight_abs,
875                dim=[
876                    i for i in range(weight_abs.dim()) if i != 1
877                ],  # sum over the output dimension
878            )
879
880        activation_abs_batch_mean = torch.mean(
881            torch.abs(activation),
882            dim=[
883                i for i in range(activation.dim()) if i != 1
884            ],  # average the features over batch samples
885        )
886
887        importance_step_layer = weight_abs_sum * activation_abs_batch_mean
888        importance_step_layer = importance_step_layer.detach()
889
890        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in CBP.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • activation (Tensor): the activation tensor of the layer. It has the same size of (number of units, ).
  • if_output_weight (bool): whether to use the output weights or input weights.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_gradient_x_activation_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
892    def get_importance_step_layer_gradient_x_activation_abs(
893        self: str,
894        layer_name: str,
895        input: Tensor | tuple[Tensor, ...],
896        target: Tensor | None,
897        batch_idx: int,
898        num_batches: int,
899    ) -> Tensor:
900        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum.
901
902        **Args:**
903        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
904        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
905        - **target** (`Tensor` | `None`): the target batch of the training step.
906        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
907        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
908
909        **Returns:**
910        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
911        """
912        layer = self.backbone.get_layer_by_name(layer_name)
913
914        input = input.requires_grad_()
915
916        # initialize the Layer Gradient X Activation object
917        layer_gradient_x_activation = LayerGradientXActivation(
918            forward_func=self.forward, layer=layer
919        )
920
921        self.set_forward_func_return_logits_only(True)
922        # calculate layer attribution of the step
923        attribution = layer_gradient_x_activation.attribute(
924            inputs=input,
925            target=target,
926            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
927        )
928        self.set_forward_func_return_logits_only(False)
929
930        attribution_abs_batch_mean = torch.mean(
931            torch.abs(attribution),
932            dim=[
933                i for i in range(attribution.dim()) if i != 1
934            ],  # average the features over batch samples
935        )
936
937        importance_step_layer = attribution_abs_batch_mean
938        importance_step_layer = importance_step_layer.detach()
939
940        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using Layer Gradient X Activation in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_weight_gradient_square_sum(self, layer_name: str, activation: torch.Tensor, if_output_weight: bool) -> torch.Tensor:
942    def get_importance_step_layer_weight_gradient_square_sum(
943        self: str,
944        layer_name: str,
945        activation: Tensor,
946        if_output_weight: bool,
947    ) -> Tensor:
948        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
949
950        **Args:**
951        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
952        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
953        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
954
955        **Returns:**
956        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
957        """
958        layer = self.backbone.get_layer_by_name(layer_name)
959
960        if not if_output_weight:
961            gradient_square = layer.weight.grad.data**2
962            gradient_square_sum = torch.sum(
963                gradient_square,
964                dim=[
965                    i for i in range(gradient_square.dim()) if i != 0
966                ],  # sum over the input dimension
967            )
968        else:
969            gradient_square = self.next_layer(layer_name).weight.grad.data**2
970            gradient_square_sum = torch.sum(
971                gradient_square,
972                dim=[
973                    i for i in range(gradient_square.dim()) if i != 1
974                ],  # sum over the output dimension
975            )
976
977        importance_step_layer = gradient_square_sum
978        importance_step_layer = importance_step_layer.detach()
979
980        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in EWC.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • activation (Tensor): the activation tensor of the layer. It has the same size of (number of units, ).
  • if_output_weight (bool): whether to use the output weights or input weights.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs( self: str, layer_name: str, activation: torch.Tensor, if_output_weight: bool) -> torch.Tensor:
 982    def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 983        self: str,
 984        layer_name: str,
 985        activation: Tensor,
 986        if_output_weight: bool,
 987    ) -> Tensor:
 988        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 989
 990        **Args:**
 991        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 992        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
 993        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 994
 995        **Returns:**
 996        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 997        """
 998        layer = self.backbone.get_layer_by_name(layer_name)
 999
1000        if not if_output_weight:
1001            gradient_square = layer.weight.grad.data**2
1002            gradient_square_sum = torch.sum(
1003                gradient_square,
1004                dim=[
1005                    i for i in range(gradient_square.dim()) if i != 0
1006                ],  # sum over the input dimension
1007            )
1008        else:
1009            gradient_square = self.next_layer(layer_name).weight.grad.data**2
1010            gradient_square_sum = torch.sum(
1011                gradient_square,
1012                dim=[
1013                    i for i in range(gradient_square.dim()) if i != 1
1014                ],  # sum over the output dimension
1015            )
1016
1017        activation_abs_batch_mean = torch.mean(
1018            torch.abs(activation),
1019            dim=[
1020                i for i in range(activation.dim()) if i != 1
1021            ],  # average the features over batch samples
1022        )
1023
1024        importance_step_layer = gradient_square_sum * activation_abs_batch_mean
1025        importance_step_layer = importance_step_layer.detach()
1026
1027        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in EWC.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • activation (Tensor): the activation tensor of the layer. It has the same size of (number of units, ).
  • if_output_weight (bool): whether to use the output weights or input weights.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_conductance_abs( self: str, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1029    def get_importance_step_layer_conductance_abs(
1030        self: str,
1031        layer_name: str,
1032        input: Tensor | tuple[Tensor, ...],
1033        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1034        target: Tensor | None,
1035        batch_idx: int,
1036        num_batches: int,
1037    ) -> Tensor:
1038        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.
1039
1040        **Args:**
1041        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1042        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1043        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
1044        - **target** (`Tensor` | `None`): the target batch of the training step.
1045        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1046        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.- **mask** (`Tensor`): the mask tensor of the layer. It has the same size as the feature tensor with size (number of units, ).
1047
1048        **Returns:**
1049        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1050        """
1051        layer = self.backbone.get_layer_by_name(layer_name)
1052
1053        # initialize the Layer Conductance object
1054        layer_conductance = LayerConductance(forward_func=self.forward, layer=layer)
1055
1056        self.set_forward_func_return_logits_only(True)
1057        # calculate layer attribution of the step
1058        attribution = layer_conductance.attribute(
1059            inputs=input,
1060            baselines=baselines,
1061            target=target,
1062            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1063        )
1064        self.set_forward_func_return_logits_only(False)
1065
1066        attribution_abs_batch_mean = torch.mean(
1067            torch.abs(attribution),
1068            dim=[
1069                i for i in range(attribution.dim()) if i != 1
1070            ],  # average the features over batch samples
1071        )
1072
1073        importance_step_layer = attribution_abs_batch_mean
1074        importance_step_layer = importance_step_layer.detach()
1075
1076        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of conductance. We implement this using Layer Conductance in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed in this method. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_internal_influence_abs( self: str, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1078    def get_importance_step_layer_internal_influence_abs(
1079        self: str,
1080        layer_name: str,
1081        input: Tensor | tuple[Tensor, ...],
1082        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1083        target: Tensor | None,
1084        batch_idx: int,
1085        num_batches: int,
1086    ) -> Tensor:
1087        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum.
1088
1089        **Args:**
1090        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1091        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1092        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details.
1093        - **target** (`Tensor` | `None`): the target batch of the training step.
1094        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1095        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1096
1097        **Returns:**
1098        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1099        """
1100        layer = self.backbone.get_layer_by_name(layer_name)
1101
1102        # initialize the Internal Influence object
1103        internal_influence = InternalInfluence(forward_func=self.forward, layer=layer)
1104
1105        # convert the target to long type to avoid error
1106        target = target.long() if target is not None else None
1107
1108        self.set_forward_func_return_logits_only(True)
1109        # calculate layer attribution of the step
1110        attribution = internal_influence.attribute(
1111            inputs=input,
1112            baselines=baselines,
1113            target=target,
1114            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1115            n_steps=5,  # set 10 instead of default 50 to accelerate the computation
1116        )
1117        self.set_forward_func_return_logits_only(False)
1118
1119        attribution_abs_batch_mean = torch.mean(
1120            torch.abs(attribution),
1121            dim=[
1122                i for i in range(attribution.dim()) if i != 1
1123            ],  # average the features over batch samples
1124        )
1125
1126        importance_step_layer = attribution_abs_batch_mean
1127        importance_step_layer = importance_step_layer.detach()
1128
1129        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of internal influence. We implement this using Internal Influence in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed in this method. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_gradcam_abs( self: str, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1131    def get_importance_step_layer_gradcam_abs(
1132        self: str,
1133        layer_name: str,
1134        input: Tensor | tuple[Tensor, ...],
1135        target: Tensor | None,
1136        batch_idx: int,
1137        num_batches: int,
1138    ) -> Tensor:
1139        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum.
1140
1141        **Args:**
1142        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1143        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1144        - **target** (`Tensor` | `None`): the target batch of the training step.
1145        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1146        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1147
1148        **Returns:**
1149        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1150        """
1151        layer = self.backbone.get_layer_by_name(layer_name)
1152
1153        # initialize the GradCAM object
1154        gradcam = LayerGradCam(forward_func=self.forward, layer=layer)
1155
1156        self.set_forward_func_return_logits_only(True)
1157        # calculate layer attribution of the step
1158        attribution = gradcam.attribute(
1159            inputs=input,
1160            target=target,
1161            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1162        )
1163        self.set_forward_func_return_logits_only(False)
1164
1165        attribution_abs_batch_mean = torch.mean(
1166            torch.abs(attribution),
1167            dim=[
1168                i for i in range(attribution.dim()) if i != 1
1169            ],  # average the features over batch samples
1170        )
1171
1172        importance_step_layer = attribution_abs_batch_mean
1173        importance_step_layer = importance_step_layer.detach()
1174
1175        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of Grad-CAM. We implement this using Layer Grad-CAM in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_deeplift_abs( self: str, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1177    def get_importance_step_layer_deeplift_abs(
1178        self: str,
1179        layer_name: str,
1180        input: Tensor | tuple[Tensor, ...],
1181        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1182        target: Tensor | None,
1183        batch_idx: int,
1184        num_batches: int,
1185    ) -> Tensor:
1186        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum.
1187
1188        **Args:**
1189        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1190        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1191        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details.
1192        - **target** (`Tensor` | `None`): the target batch of the training step.
1193        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1194        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1195
1196        **Returns:**
1197        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1198        """
1199        layer = self.backbone.get_layer_by_name(layer_name)
1200
1201        # initialize the Layer DeepLift object
1202        layer_deeplift = LayerDeepLift(model=self, layer=layer)
1203
1204        # convert the target to long type to avoid error
1205        target = target.long() if target is not None else None
1206
1207        self.set_forward_func_return_logits_only(True)
1208        # calculate layer attribution of the step
1209        attribution = layer_deeplift.attribute(
1210            inputs=input,
1211            baselines=baselines,
1212            target=target,
1213            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1214        )
1215        self.set_forward_func_return_logits_only(False)
1216
1217        attribution_abs_batch_mean = torch.mean(
1218            torch.abs(attribution),
1219            dim=[
1220                i for i in range(attribution.dim()) if i != 1
1221            ],  # average the features over batch samples
1222        )
1223
1224        importance_step_layer = attribution_abs_batch_mean
1225        importance_step_layer = importance_step_layer.detach()
1226
1227        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of DeepLift. We implement this using Layer DeepLift in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): baselines define reference samples that are compared with the inputs. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_deepliftshap_abs( self: str, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1229    def get_importance_step_layer_deepliftshap_abs(
1230        self: str,
1231        layer_name: str,
1232        input: Tensor | tuple[Tensor, ...],
1233        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1234        target: Tensor | None,
1235        batch_idx: int,
1236        num_batches: int,
1237    ) -> Tensor:
1238        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.
1239
1240        **Args:**
1241        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1242        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1243        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
1244        - **target** (`Tensor` | `None`): the target batch of the training step.
1245        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1246        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1247
1248        **Returns:**
1249        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1250        """
1251        layer = self.backbone.get_layer_by_name(layer_name)
1252
1253        # initialize the Layer DeepLiftShap object
1254        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)
1255
1256        # convert the target to long type to avoid error
1257        target = target.long() if target is not None else None
1258
1259        self.set_forward_func_return_logits_only(True)
1260        # calculate layer attribution of the step
1261        attribution = layer_deepliftshap.attribute(
1262            inputs=input,
1263            baselines=baselines,
1264            target=target,
1265            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1266        )
1267        self.set_forward_func_return_logits_only(False)
1268
1269        attribution_abs_batch_mean = torch.mean(
1270            torch.abs(attribution),
1271            dim=[
1272                i for i in range(attribution.dim()) if i != 1
1273            ],  # average the features over batch samples
1274        )
1275
1276        importance_step_layer = attribution_abs_batch_mean
1277        importance_step_layer = importance_step_layer.detach()
1278
1279        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of DeepLift SHAP. We implement this using Layer DeepLiftShap in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): baselines define reference samples that are compared with the inputs. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_gradientshap_abs( self: str, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1281    def get_importance_step_layer_gradientshap_abs(
1282        self: str,
1283        layer_name: str,
1284        input: Tensor | tuple[Tensor, ...],
1285        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1286        target: Tensor | None,
1287        batch_idx: int,
1288        num_batches: int,
1289    ) -> Tensor:
1290        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.
1291
1292        **Args:**
1293        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1294        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1295        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
1296        - **target** (`Tensor` | `None`): the target batch of the training step.
1297        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1298        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1299
1300        **Returns:**
1301        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1302        """
1303        layer = self.backbone.get_layer_by_name(layer_name)
1304
1305        if baselines is None:
1306            baselines = torch.zeros_like(
1307                input
1308            )  # baselines are mandatory for GradientShap API. We explicitly set them to zero
1309
1310        # initialize the Layer GradientShap object
1311        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)
1312
1313        # convert the target to long type to avoid error
1314        target = target.long() if target is not None else None
1315
1316        self.set_forward_func_return_logits_only(True)
1317        # calculate layer attribution of the step
1318        attribution = layer_gradientshap.attribute(
1319            inputs=input,
1320            baselines=baselines,
1321            target=target,
1322            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1323        )
1324        self.set_forward_func_return_logits_only(False)
1325
1326        attribution_abs_batch_mean = torch.mean(
1327            torch.abs(attribution),
1328            dim=[
1329                i for i in range(attribution.dim()) if i != 1
1330            ],  # average the features over batch samples
1331        )
1332
1333        importance_step_layer = attribution_abs_batch_mean
1334        importance_step_layer = importance_step_layer.detach()
1335
1336        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using Layer GradientShap in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which expectation is computed. Please refer to the Captum documentation for more details. If None, the baselines are set to zero.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_integrated_gradients_abs( self: str, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1338    def get_importance_step_layer_integrated_gradients_abs(
1339        self: str,
1340        layer_name: str,
1341        input: Tensor | tuple[Tensor, ...],
1342        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1343        target: Tensor | None,
1344        batch_idx: int,
1345        num_batches: int,
1346    ) -> Tensor:
1347        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.
1348
1349        **Args:**
1350        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1351        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1352        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
1353        - **target** (`Tensor` | `None`): the target batch of the training step.
1354        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1355        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1356
1357        **Returns:**
1358        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1359        """
1360        layer = self.backbone.get_layer_by_name(layer_name)
1361
1362        # initialize the Layer Integrated Gradients object
1363        layer_integrated_gradients = LayerIntegratedGradients(
1364            forward_func=self.forward, layer=layer
1365        )
1366
1367        self.set_forward_func_return_logits_only(True)
1368        # calculate layer attribution of the step
1369        attribution = layer_integrated_gradients.attribute(
1370            inputs=input,
1371            baselines=baselines,
1372            target=target,
1373            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1374        )
1375        self.set_forward_func_return_logits_only(False)
1376
1377        attribution_abs_batch_mean = torch.mean(
1378            torch.abs(attribution),
1379            dim=[
1380                i for i in range(attribution.dim()) if i != 1
1381            ],  # average the features over batch samples
1382        )
1383
1384        importance_step_layer = attribution_abs_batch_mean
1385        importance_step_layer = importance_step_layer.detach()
1386
1387        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of integrated gradients. We implement this using Layer Integrated Gradients in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_feature_ablation_abs( self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], layer_baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int, if_captum: bool = False) -> torch.Tensor:
1389    def get_importance_step_layer_feature_ablation_abs(
1390        self: str,
1391        layer_name: str,
1392        input: Tensor | tuple[Tensor, ...],
1393        layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1394        target: Tensor | None,
1395        batch_idx: int,
1396        num_batches: int,
1397        if_captum: bool = False,
1398    ) -> Tensor:
1399        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum.
1400
1401        **Args:**
1402        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1403        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1404        - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details.
1405        - **target** (`Tensor` | `None`): the target batch of the training step.
1406        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1407        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1408        - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster.
1409
1410        **Returns:**
1411        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1412        """
1413        layer = self.backbone.get_layer_by_name(layer_name)
1414
1415        if not if_captum:
1416            # 1. Baseline logits (take first element of forward output)
1417            baseline_out, _, _ = self.forward(
1418                input, "train", batch_idx, num_batches, self.task_id
1419            )
1420            if target is not None:
1421                baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1)
1422            else:
1423                baseline_scores = baseline_out.sum(dim=1)
1424
1425            # 2. Capture layer’s output shape
1426            activs = {}
1427            handle = layer.register_forward_hook(
1428                lambda module, inp, out: activs.setdefault("output", out.detach())
1429            )
1430            _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id)
1431            handle.remove()
1432            layer_output = activs["output"]  # shape (B, F, ...)
1433
1434            # 3. Build baseline tensor matching that shape
1435            if layer_baselines is None:
1436                baseline_tensor = torch.zeros_like(layer_output)
1437            elif isinstance(layer_baselines, (int, float)):
1438                baseline_tensor = torch.full_like(layer_output, layer_baselines)
1439            elif isinstance(layer_baselines, Tensor):
1440                if layer_baselines.shape == layer_output.shape:
1441                    baseline_tensor = layer_baselines
1442                elif layer_baselines.shape == layer_output.shape[1:]:
1443                    baseline_tensor = layer_baselines.unsqueeze(0).repeat(
1444                        layer_output.size(0), *([1] * layer_baselines.ndim)
1445                    )
1446                else:
1447                    raise ValueError(...)
1448            else:
1449                raise ValueError(...)
1450
1451            B, F = layer_output.size(0), layer_output.size(1)
1452
1453            # 4. Create a “mega-batch” replicating the input F times
1454            if isinstance(input, tuple):
1455                mega_inputs = tuple(
1456                    t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:])
1457                    for t in input
1458                )
1459            else:
1460                mega_inputs = (
1461                    input.unsqueeze(0)
1462                    .repeat(F, *([1] * input.ndim))
1463                    .view(-1, *input.shape[1:])
1464                )
1465
1466            # 5. Equally replicate the baseline tensor
1467            mega_baseline = (
1468                baseline_tensor.unsqueeze(0)
1469                .repeat(F, *([1] * baseline_tensor.ndim))
1470                .view(-1, *baseline_tensor.shape[1:])
1471            )
1472
1473            # 6. Precompute vectorized indices
1474            device = layer_output.device
1475            positions = torch.arange(F * B, device=device)  # [0,1,...,F*B-1]
1476            feat_idx = torch.arange(F, device=device).repeat_interleave(
1477                B
1478            )  # [0,0,...,1,1,...,F-1]
1479
1480            # 7. One hook to zero out each channel slice across the mega-batch
1481            def mega_ablate_hook(module, inp, out):
1482                out_mod = out.clone()
1483                # for each sample in mega-batch, zero its corresponding channel
1484                out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx]
1485                return out_mod
1486
1487            h = layer.register_forward_hook(mega_ablate_hook)
1488            out_all, _, _ = self.forward(
1489                mega_inputs, "train", batch_idx, num_batches, self.task_id
1490            )
1491            h.remove()
1492
1493            # 8. Recover scores, reshape [F*B] → [F, B], diff & mean
1494            if target is not None:
1495                tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1)
1496                scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1)
1497            else:
1498                scores_all = out_all.sum(dim=1)
1499
1500            scores_all = scores_all.view(F, B)
1501            diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all)
1502            importance_step_layer = diffs.mean(dim=1).detach()  # [F]
1503
1504            return importance_step_layer
1505
1506        else:
1507            # initialize the Layer Feature Ablation object
1508            layer_feature_ablation = LayerFeatureAblation(
1509                forward_func=self.forward, layer=layer
1510            )
1511
1512            # calculate layer attribution of the step
1513            self.set_forward_func_return_logits_only(True)
1514            attribution = layer_feature_ablation.attribute(
1515                inputs=input,
1516                layer_baselines=layer_baselines,
1517                # target=target, # disable target to enable perturbations_per_eval
1518                additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1519                perturbations_per_eval=128,  # to accelerate the computation
1520            )
1521            self.set_forward_func_return_logits_only(False)
1522
1523            attribution_abs_batch_mean = torch.mean(
1524                torch.abs(attribution),
1525                dim=[
1526                    i for i in range(attribution.dim()) if i != 1
1527                ],  # average the features over batch samples
1528            )
1529
1530        importance_step_layer = attribution_abs_batch_mean
1531        importance_step_layer = importance_step_layer.detach()
1532
1533        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of feature ablation attribution. We implement this using Layer Feature Ablation in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • layer_baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): reference values which replace each layer input / output value when ablated. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.
  • if_captum (bool): whether to use Captum or not. If True, we use Captum to calculate the feature ablation. If False, we use our implementation. Default is False, because our implementation is much faster.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_lrp_abs( self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1535    def get_importance_step_layer_lrp_abs(
1536        self: str,
1537        layer_name: str,
1538        input: Tensor | tuple[Tensor, ...],
1539        target: Tensor | None,
1540        batch_idx: int,
1541        num_batches: int,
1542    ) -> Tensor:
1543        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum.
1544
1545        **Args:**
1546        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1547        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1548        - **target** (`Tensor` | `None`): the target batch of the training step.
1549        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1550        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1551
1552        **Returns:**
1553        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1554        """
1555        layer = self.backbone.get_layer_by_name(layer_name)
1556
1557        # initialize the Layer LRP object
1558        layer_lrp = LayerLRP(model=self, layer=layer)
1559
1560        # set model to evaluation mode to prevent updating the model parameters
1561        self.eval()
1562
1563        self.set_forward_func_return_logits_only(True)
1564        # calculate layer attribution of the step
1565        attribution = layer_lrp.attribute(
1566            inputs=input,
1567            target=target,
1568            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1569        )
1570        self.set_forward_func_return_logits_only(False)
1571
1572        attribution_abs_batch_mean = torch.mean(
1573            torch.abs(attribution),
1574            dim=[
1575                i for i in range(attribution.dim()) if i != 1
1576            ],  # average the features over batch samples
1577        )
1578
1579        importance_step_layer = attribution_abs_batch_mean
1580        importance_step_layer = importance_step_layer.detach()
1581
1582        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of LRP. We implement this using Layer LRP in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_cbp_adaptive_contribution(self, layer_name: str, activation: torch.Tensor) -> torch.Tensor:
1584    def get_importance_step_layer_cbp_adaptive_contribution(
1585        self: str,
1586        layer_name: str,
1587        activation: Tensor,
1588    ) -> Tensor:
1589        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the reciprocal of sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
1590
1591        **Args:**
1592        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1593        - **activation** (`Tensor`): the activation tensor of the layer. It has the same size of (number of units, ).
1594
1595        **Returns:**
1596        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1597        """
1598        layer = self.backbone.get_layer_by_name(layer_name)
1599
1600        input_weight_abs = torch.abs(layer.weight.data)
1601        input_weight_abs_sum = torch.sum(
1602            input_weight_abs,
1603            dim=[
1604                i for i in range(input_weight_abs.dim()) if i != 0
1605            ],  # sum over the input dimension
1606        )
1607        input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum)
1608
1609        output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
1610        output_weight_abs_sum = torch.sum(
1611            output_weight_abs,
1612            dim=[
1613                i for i in range(output_weight_abs.dim()) if i != 1
1614            ],  # sum over the output dimension
1615        )
1616
1617        activation_abs_batch_mean = torch.mean(
1618            torch.abs(activation),
1619            dim=[
1620                i for i in range(activation.dim()) if i != 1
1621            ],  # average the features over batch samples
1622        )
1623
1624        importance_step_layer = (
1625            output_weight_abs_sum
1626            * activation_abs_batch_mean
1627            * input_weight_abs_sum_reciprocal
1628        )
1629        importance_step_layer = importance_step_layer.detach()
1630
1631        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by the batch-averaged absolute values of activation, divided by the sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in CBP.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • activation (Tensor): the activation tensor of the layer, of size (batch size, number of units); the batch dimension is averaged out.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.