clarena.cl_algorithms.fgadahat

The submodule in cl_algorithms for the FG-AdaHAT algorithm.
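
A minimal instantiation sketch based on the constructor documented in the source below. It assumes `FGAdaHAT` is re-exported by `clarena.cl_algorithms` (as `AdaHAT` is in the imports below); the `backbone` and `heads` objects and all hyperparameter values are illustrative placeholders rather than recommended settings, and in practice they typically come from the experiment config.

    from clarena.cl_algorithms import FGAdaHAT

    # `backbone` (a HATMaskBackbone) and `heads` (a HeadsTIL) are assumed to be
    # constructed elsewhere; their own arguments are omitted here.
    algorithm = FGAdaHAT(
        backbone=backbone,
        heads=heads,
        adjustment_intensity=1e-3,  # alpha in the paper (illustrative value)
        importance_type="output_weight_abs_sum_x_activation_abs",  # CU in the paper
        importance_summing_strategy="add_latest",
        importance_scheduler_type="linear_sparsity_reg",
        neuron_to_weight_importance_aggregation_mode="min",
        s_max=400.0,  # illustrative HAT gate scaling
        clamp_threshold=50.0,  # illustrative embedding gradient compensation threshold
        mask_sparsity_reg_factor=0.75,  # illustrative regularization factor
    )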

   1r"""
   2The submodule in `cl_algorithms` for the FG-AdaHAT algorithm.
   3"""
   4
   5__all__ = ["FGAdaHAT"]
   6
   7import logging
   8import math
   9from typing import Any
  10
  11import torch
  12from captum.attr import (
  13    InternalInfluence,
  14    LayerConductance,
  15    LayerDeepLift,
  16    LayerDeepLiftShap,
  17    LayerFeatureAblation,
  18    LayerGradCam,
  19    LayerGradientShap,
  20    LayerGradientXActivation,
  21    LayerIntegratedGradients,
  22    LayerLRP,
  23)
  24from torch import Tensor
  25
  26from clarena.backbones import HATMaskBackbone
  27from clarena.cl_algorithms import AdaHAT
  28from clarena.heads import HeadsTIL
  29from clarena.utils.metrics import HATNetworkCapacityMetric
  30from clarena.utils.transforms import min_max_normalize
  31
  32# always get logger for built-in logging in each module
  33pylogger = logging.getLogger(__name__)
  34
  35
  36class FGAdaHAT(AdaHAT):
  37    r"""FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm.
  38
  39    An architecture-based continual learning approach that improves [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
  40
  41    We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components.
  42    """
  43
  44    def __init__(
  45        self,
  46        backbone: HATMaskBackbone,
  47        heads: HeadsTIL,
  48        adjustment_intensity: float,
  49        importance_type: str,
  50        importance_summing_strategy: str,
  51        importance_scheduler_type: str,
  52        neuron_to_weight_importance_aggregation_mode: str,
  53        s_max: float,
  54        clamp_threshold: float,
  55        mask_sparsity_reg_factor: float,
  56        mask_sparsity_reg_mode: str = "original",
  57        base_importance: float = 0.01,
  58        base_mask_sparsity_reg: float = 0.1,
  59        base_linear: float = 10,
  60        filter_by_cumulative_mask: bool = False,
  61        filter_unmasked_importance: bool = True,
  62        step_multiply_training_mask: bool = True,
  63        task_embedding_init_mode: str = "N01",
  64        importance_summing_strategy_linear_step: float | None = None,
  65        importance_summing_strategy_exponential_rate: float | None = None,
  66        importance_summing_strategy_log_base: float | None = None,
  67        non_algorithmic_hparams: dict[str, Any] = {},
  68    ) -> None:
  69        r"""Initialize the FG-AdaHAT algorithm with the network.
  70
  71        **Args:**
  72        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
  73        - **heads** (`HeadsTIL`): output heads. FG-AdaHAT supports only TIL (Task-Incremental Learning).
  74        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
  75        - **importance_type** (`str`): the type of neuron-wise importance, must be one of:
  76            1. 'input_weight_abs_sum': sum of absolute input weights;
  77            2. 'output_weight_abs_sum': sum of absolute output weights;
  78            3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
  79            4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
  80            5. 'activation_abs': absolute activation;
  81            6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
  82            7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
  83            8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
  84            9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
  85            10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
  86            11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
  87            12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
  88            13. 'conductance_abs': absolute layer conductance;
  89            14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
  90            15. 'gradcam_abs': absolute Grad-CAM;
  91            16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
  92            17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
  93            18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
  94            19. 'integrated_gradients_abs': absolute Integrated Gradients;
  95            20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
  96            21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
  97            22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
  98            23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
  99        - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of:
 100            1. 'add_latest': add the latest neuron-wise importance to the summative importance;
 101            2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
 102            3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
 103            4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
  104            5. 'quadratic_decrease': weigh the previous neuron-wise importance by a factor that decreases quadratically with the task ID;
  105            6. 'cubic_decrease': weigh the previous neuron-wise importance by a factor that decreases cubically with the task ID;
 106            7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
 107            8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
  108            9. 'factorial_decrease': weigh the previous neuron-wise importance by a factor that decreases factorially with the task ID;
  109        - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ applied to the parameter importance. Must be one of:
  110            1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`);
 111            2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
 112            3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
 113        - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
 114            1. 'min': take the minimum of neuron-wise importance for each weight;
 115            2. 'max': take the maximum of neuron-wise importance for each weight;
 116            3. 'mean': take the mean of neuron-wise importance for each weight.
 117        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 118        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 119        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 120        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 121            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 122            2. 'cross': the cross version of mask sparsity regularization.
 123        - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01.
 124        - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
 125        - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
  126        - **filter_by_cumulative_mask** (`bool`): whether to multiply the importance by the cumulative mask when calculating the adjustment rate. Default: False.
  127        - **filter_unmasked_importance** (`bool`): whether to filter out unmasked importance values (set them to 0) at the end of task training. Default: True.
  128        - **step_multiply_training_mask** (`bool`): whether to multiply the importance by the training mask at each training step. Default: True.
 129        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 130            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 131            2. 'U-11': uniform distribution $U(-1, 1)$.
 132            3. 'U01': uniform distribution $U(0, 1)$.
 133            4. 'U-10': uniform distribution $U(-1, 0)$.
 134            5. 'last': inherit the task embedding from the last task.
 135        - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0.
 136        - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1.
 137        - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1.
  138        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (those not related to the algorithm itself) passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
 139
 140        """
 141        super().__init__(
 142            backbone=backbone,
 143            heads=heads,
  144            adjustment_mode=None,  # use FG-AdaHAT's own adjustment mechanism
 145            adjustment_intensity=adjustment_intensity,
 146            s_max=s_max,
 147            clamp_threshold=clamp_threshold,
 148            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
 149            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
 150            task_embedding_init_mode=task_embedding_init_mode,
 151            epsilon=base_mask_sparsity_reg,  # the epsilon is now the base mask sparsity regularization factor
 152            non_algorithmic_hparams=non_algorithmic_hparams,
 153        )
 154
 155        # save additional algorithmic hyperparameters
 156        self.save_hyperparameters(
 157            "adjustment_intensity",
 158            "importance_type",
 159            "importance_summing_strategy",
 160            "importance_scheduler_type",
 161            "neuron_to_weight_importance_aggregation_mode",
 162            "s_max",
 163            "clamp_threshold",
 164            "mask_sparsity_reg_factor",
 165            "mask_sparsity_reg_mode",
 166            "base_importance",
 167            "base_mask_sparsity_reg",
 168            "base_linear",
 169            "filter_by_cumulative_mask",
 170            "filter_unmasked_importance",
 171            "step_multiply_training_mask",
 172        )
 173
 174        self.importance_type: str | None = importance_type
 175        r"""The type of the neuron-wise importance added to AdaHAT importance."""
 176
 177        self.importance_scheduler_type: str = importance_scheduler_type
 178        r"""The type of the importance scheduler."""
 179        self.neuron_to_weight_importance_aggregation_mode: str = (
 180            neuron_to_weight_importance_aggregation_mode
 181        )
 182        r"""The mode of aggregation from neuron-wise to weight-wise importance. """
 183        self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask
 184        r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate."""
 185        self.filter_unmasked_importance: bool = filter_unmasked_importance
 186        r"""The flag to filter unmasked importance values (set them to 0) at the end of task training."""
 187        self.step_multiply_training_mask: bool = step_multiply_training_mask
 188        r"""The flag to multiply the training mask to the importance at each training step."""
 189
 190        # importance summing strategy
 191        self.importance_summing_strategy: str = importance_summing_strategy
 192        r"""The strategy to sum the neuron-wise importance for previous tasks."""
 193        if importance_summing_strategy_linear_step is not None:
 194            self.importance_summing_strategy_linear_step: float = (
 195                importance_summing_strategy_linear_step
 196            )
 197            r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease')."""
 198        if importance_summing_strategy_exponential_rate is not None:
 199            self.importance_summing_strategy_exponential_rate: float = (
 200                importance_summing_strategy_exponential_rate
 201            )
 202            r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """
 203        if importance_summing_strategy_log_base is not None:
 204            self.importance_summing_strategy_log_base: float = (
 205                importance_summing_strategy_log_base
 206            )
 207            r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """
 208
 209        # base values
 210        self.base_importance: float = base_importance
 211        r"""The base value added to the importance to avoid zero. """
 212        self.base_mask_sparsity_reg: float = base_mask_sparsity_reg
 213        r"""The base value added to the mask sparsity regularization to avoid zero. """
 214        self.base_linear: float = base_linear
 215        r"""The base value added to the linear layer to avoid zero. """
 216
 217        self.importances: dict[int, dict[str, Tensor]] = {}
 218        r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """
 219        self.summative_importance_for_previous_tasks: dict[str, Tensor] = {}
 220        r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """
 221
 222        self.num_steps_t: int
 223        r"""The number of training steps for the current task `self.task_id`."""
 224        # set manual optimization
 225        self.automatic_optimization = False
 226
 227        FGAdaHAT.sanity_check(self)
 228
 229    def sanity_check(self) -> None:
 230        r"""Sanity check."""
 231
 232        # check importance type
 233        if self.importance_type not in [
 234            "input_weight_abs_sum",
 235            "output_weight_abs_sum",
 236            "input_weight_gradient_abs_sum",
 237            "output_weight_gradient_abs_sum",
 238            "activation_abs",
 239            "input_weight_abs_sum_x_activation_abs",
 240            "output_weight_abs_sum_x_activation_abs",
 241            "gradient_x_activation_abs",
 242            "input_weight_gradient_square_sum",
 243            "output_weight_gradient_square_sum",
 244            "input_weight_gradient_square_sum_x_activation_abs",
 245            "output_weight_gradient_square_sum_x_activation_abs",
 246            "conductance_abs",
 247            "internal_influence_abs",
 248            "gradcam_abs",
 249            "deeplift_abs",
 250            "deepliftshap_abs",
 251            "gradientshap_abs",
 252            "integrated_gradients_abs",
 253            "feature_ablation_abs",
 254            "lrp_abs",
 255            "cbp_adaptation",
 256            "cbp_adaptive_contribution",
 257        ]:
 258            raise ValueError(
 259                f"importance_type must be one of the predefined types, but got {self.importance_type}"
 260            )
 261
 262        # check importance summing strategy
 263        if self.importance_summing_strategy not in [
 264            "add_latest",
 265            "add_all",
 266            "add_average",
 267            "linear_decrease",
 268            "quadratic_decrease",
 269            "cubic_decrease",
 270            "exponential_decrease",
 271            "log_decrease",
 272            "factorial_decrease",
 273        ]:
 274            raise ValueError(
 275                f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}"
 276            )
 277
 278        # check importance scheduler type
 279        if self.importance_scheduler_type not in [
 280            "linear_sparsity_reg",
 281            "sparsity_reg",
 282            "summative_mask_sparsity_reg",
 283        ]:
 284            raise ValueError(
 285                f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}"
 286            )
 287
 288        # check neuron to weight importance aggregation mode
 289        if self.neuron_to_weight_importance_aggregation_mode not in [
 290            "min",
 291            "max",
 292            "mean",
 293        ]:
 294            raise ValueError(
 295                f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}"
 296            )
 297
 298        # check base values
 299        if self.base_importance < 0:
 300            raise ValueError(
 301                f"base_importance must be >= 0, but got {self.base_importance}"
 302            )
 303        if self.base_mask_sparsity_reg <= 0:
 304            raise ValueError(
 305                f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}"
 306            )
 307        if self.base_linear <= 0:
 308            raise ValueError(f"base_linear must be > 0, but got {self.base_linear}")
 309
 310    def on_train_start(self) -> None:
 311        r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization."""
 312        super().on_train_start()
 313
 314        self.importances[self.task_id] = (
 315            {}
 316        )  # initialize the importance for the current task
 317
 318        # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time.
 319        for layer_name in self.backbone.weighted_layer_names:
 320            layer = self.backbone.get_layer_by_name(
 321                layer_name
 322            )  # get the layer by its name
 323            num_units = layer.weight.shape[0]
 324
 325            # initialize the accumulated importance at the beginning of each task
 326            self.importances[self.task_id][layer_name] = torch.zeros(num_units).to(
 327                self.device
 328            )
 329
 330            # reset the number of steps counter for the current task
 331            self.num_steps_t = 0
 332
 333            # initialize the summative neuron-wise importance at the beginning of the first task
 334            if self.task_id == 1:
 335                self.summative_importance_for_previous_tasks[layer_name] = torch.zeros(
 336                    num_units
 337                ).to(
 338                    self.device
 339                )  # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$
 340
 341    def clip_grad_by_adjustment(
 342        self,
 343        network_sparsity: dict[str, Tensor],
 344    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
 345        r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper.
 346
  347        Note that because the task embeddings fully cover every layer in the backbone network, no parameters are left out of this mechanism. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.
 348
 349        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 350
 351        **Args:**
 352        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.
 353
 354        **Returns:**
 355        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
 356        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
 357        - **capacity** (`Tensor`): the calculated network capacity.
 358        """
 359
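             # An illustrative reading of the adjustment rate computed below,
             # r = alpha / (c^t * I + alpha) (see Eq. (2) in the paper): a parameter
             # whose aggregated importance I is large gets r close to 0, so its
             # gradient is nearly frozen, while I close to 0 gives r close to 1, a
             # nearly free update. For example, with alpha = 1e-3 and c^t = 1,
             # I = 1 yields r of roughly 1e-3, and I = 0 yields r = 1 exactly.
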
 360        # initialize network capacity metric
 361        capacity = HATNetworkCapacityMetric().to(self.device)
 362        adjustment_rate_weight = {}
 363        adjustment_rate_bias = {}
 364
 365        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper
 366        for layer_name in self.backbone.weighted_layer_names:
 367
 368            layer = self.backbone.get_layer_by_name(
 369                layer_name
 370            )  # get the layer by its name
 371
  372            # placeholders for the adjustment rates to avoid use-before-assignment errors
 373            adjustment_rate_weight_layer = 1
 374            adjustment_rate_bias_layer = 1
 375
  376            # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in `on_train_batch_end()`, had the base value added, and been filtered by the mask
 377            weight_importance, bias_importance = (
 378                self.backbone.get_layer_measure_parameter_wise(
 379                    neuron_wise_measure=self.summative_importance_for_previous_tasks,
 380                    layer_name=layer_name,
 381                    aggregation_mode=self.neuron_to_weight_importance_aggregation_mode,
 382                )
 383            )
 384
 385            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
 386                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
 387                layer_name=layer_name,
 388                aggregation_mode="min",
 389            )
 390
 391            # filter the weight importance by the cumulative mask
 392            if self.filter_by_cumulative_mask:
 393                weight_importance = weight_importance * weight_mask
 394                bias_importance = bias_importance * bias_mask
 395
 396            network_sparsity_layer = network_sparsity[layer_name]
 397
  398            # calculate the importance scheduler factor $c^t$. See Eq. (3) in the paper
  399            factor = network_sparsity_layer + self.base_mask_sparsity_reg
  400            factor_weight, factor_bias = factor, factor
  401            if self.importance_scheduler_type == "linear_sparsity_reg":
  402                factor_weight = factor * (self.task_id + self.base_linear)
  403                factor_bias = factor_weight
  404            elif self.importance_scheduler_type == "sparsity_reg":
  405                pass  # keep the base factor as is
  406            elif self.importance_scheduler_type == "summative_mask_sparsity_reg":
  407                # parameter-wise min over the summative masks of adjacent layers in Eq. (3)
  408                summative_mask_weight, summative_mask_bias = (
  409                    self.backbone.get_layer_measure_parameter_wise(
  410                        neuron_wise_measure=self.summative_mask_for_previous_tasks,
  411                        layer_name=layer_name,
  412                        aggregation_mode="min",
  413                    )
  414                )
  415                factor_weight = factor * (summative_mask_weight + self.base_linear)
  416                factor_bias = factor * (summative_mask_bias + self.base_linear)
  417
  418            # calculate the adjustment rate
  419            adjustment_rate_weight_layer = torch.div(
  420                self.adjustment_intensity,
  421                (factor_weight * weight_importance + self.adjustment_intensity),
  422            )
  423
  424            adjustment_rate_bias_layer = torch.div(
  425                self.adjustment_intensity,
  426                (factor_bias * bias_importance + self.adjustment_intensity),
  427            )
 419
 420            # apply the adjustment rate to the gradients
 421            layer.weight.grad.data *= adjustment_rate_weight_layer
 422            if layer.bias is not None:
 423                layer.bias.grad.data *= adjustment_rate_bias_layer
 424
 425            # store the adjustment rate for logging
 426            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
 427            if layer.bias is not None:
 428                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
 429
 430            # update network capacity metric
 431            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
 432
 433        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
 434
 435    def on_train_batch_end(
 436        self, outputs: dict[str, Any], batch: Any, batch_idx: int
 437    ) -> None:
 438        r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.
 439
 440        **Args:**
 441        - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`).
 442        - **batch** (`Any`): training data batch.
 443        - **batch_idx** (`int`): index of the current batch (for mask figure file name).
 444        """
 445
 446        # get potential useful information from training batch
 447        activations = outputs["activations"]
 448        input = outputs["input"]
 449        target = outputs["target"]
 450        mask = outputs["mask"]
 451        num_batches = self.trainer.num_training_batches
 452
 453        for layer_name in self.backbone.weighted_layer_names:
 454            # layer-wise operation
 455
 456            activation = activations[layer_name]
 457
 458            # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper.
 459            if self.importance_type == "input_weight_abs_sum":
 460                importance_step = self.get_importance_step_layer_weight_abs_sum(
 461                    layer_name=layer_name,
 462                    if_output_weight=False,
 463                    reciprocal=False,
 464                )
 465            elif self.importance_type == "output_weight_abs_sum":
 466                importance_step = self.get_importance_step_layer_weight_abs_sum(
 467                    layer_name=layer_name,
 468                    if_output_weight=True,
 469                    reciprocal=False,
 470                )
 471            elif self.importance_type == "input_weight_gradient_abs_sum":
 472                importance_step = (
 473                    self.get_importance_step_layer_weight_gradient_abs_sum(
 474                        layer_name=layer_name, if_output_weight=False
 475                    )
 476                )
 477            elif self.importance_type == "output_weight_gradient_abs_sum":
 478                importance_step = (
 479                    self.get_importance_step_layer_weight_gradient_abs_sum(
 480                        layer_name=layer_name, if_output_weight=True
 481                    )
 482                )
 483            elif self.importance_type == "activation_abs":
 484                importance_step = self.get_importance_step_layer_activation_abs(
 485                    activation=activation
 486                )
 487            elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
 488                importance_step = (
 489                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
 490                        layer_name=layer_name,
 491                        activation=activation,
 492                        if_output_weight=False,
 493                    )
 494                )
 495            elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
 496                importance_step = (
 497                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
 498                        layer_name=layer_name,
 499                        activation=activation,
 500                        if_output_weight=True,
 501                    )
 502                )
 503            elif self.importance_type == "gradient_x_activation_abs":
 504                importance_step = (
 505                    self.get_importance_step_layer_gradient_x_activation_abs(
 506                        layer_name=layer_name,
 507                        input=input,
 508                        target=target,
 509                        batch_idx=batch_idx,
 510                        num_batches=num_batches,
 511                    )
 512                )
 513            elif self.importance_type == "input_weight_gradient_square_sum":
 514                importance_step = (
 515                    self.get_importance_step_layer_weight_gradient_square_sum(
 516                        layer_name=layer_name,
 517                        activation=activation,
 518                        if_output_weight=False,
 519                    )
 520                )
 521            elif self.importance_type == "output_weight_gradient_square_sum":
 522                importance_step = (
 523                    self.get_importance_step_layer_weight_gradient_square_sum(
 524                        layer_name=layer_name,
 525                        activation=activation,
 526                        if_output_weight=True,
 527                    )
 528                )
 529            elif (
 530                self.importance_type
 531                == "input_weight_gradient_square_sum_x_activation_abs"
 532            ):
 533                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 534                    layer_name=layer_name,
 535                    activation=activation,
 536                    if_output_weight=False,
 537                )
 538            elif (
 539                self.importance_type
 540                == "output_weight_gradient_square_sum_x_activation_abs"
 541            ):
 542                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 543                    layer_name=layer_name,
 544                    activation=activation,
 545                    if_output_weight=True,
 546                )
 547            elif self.importance_type == "conductance_abs":
 548                importance_step = self.get_importance_step_layer_conductance_abs(
 549                    layer_name=layer_name,
 550                    input=input,
 551                    baselines=None,
 552                    target=target,
 553                    batch_idx=batch_idx,
 554                    num_batches=num_batches,
 555                )
 556            elif self.importance_type == "internal_influence_abs":
 557                importance_step = self.get_importance_step_layer_internal_influence_abs(
 558                    layer_name=layer_name,
 559                    input=input,
 560                    baselines=None,
 561                    target=target,
 562                    batch_idx=batch_idx,
 563                    num_batches=num_batches,
 564                )
 565            elif self.importance_type == "gradcam_abs":
 566                importance_step = self.get_importance_step_layer_gradcam_abs(
 567                    layer_name=layer_name,
 568                    input=input,
 569                    target=target,
 570                    batch_idx=batch_idx,
 571                    num_batches=num_batches,
 572                )
 573            elif self.importance_type == "deeplift_abs":
 574                importance_step = self.get_importance_step_layer_deeplift_abs(
 575                    layer_name=layer_name,
 576                    input=input,
 577                    baselines=None,
 578                    target=target,
 579                    batch_idx=batch_idx,
 580                    num_batches=num_batches,
 581                )
 582            elif self.importance_type == "deepliftshap_abs":
 583                importance_step = self.get_importance_step_layer_deepliftshap_abs(
 584                    layer_name=layer_name,
 585                    input=input,
 586                    baselines=None,
 587                    target=target,
 588                    batch_idx=batch_idx,
 589                    num_batches=num_batches,
 590                )
 591            elif self.importance_type == "gradientshap_abs":
 592                importance_step = self.get_importance_step_layer_gradientshap_abs(
 593                    layer_name=layer_name,
 594                    input=input,
 595                    baselines=None,
 596                    target=target,
 597                    batch_idx=batch_idx,
 598                    num_batches=num_batches,
 599                )
 600            elif self.importance_type == "integrated_gradients_abs":
 601                importance_step = (
 602                    self.get_importance_step_layer_integrated_gradients_abs(
 603                        layer_name=layer_name,
 604                        input=input,
 605                        baselines=None,
 606                        target=target,
 607                        batch_idx=batch_idx,
 608                        num_batches=num_batches,
 609                    )
 610                )
 611            elif self.importance_type == "feature_ablation_abs":
 612                importance_step = self.get_importance_step_layer_feature_ablation_abs(
 613                    layer_name=layer_name,
 614                    input=input,
 615                    layer_baselines=None,
 616                    target=target,
 617                    batch_idx=batch_idx,
 618                    num_batches=num_batches,
 619                )
 620            elif self.importance_type == "lrp_abs":
 621                importance_step = self.get_importance_step_layer_lrp_abs(
 622                    layer_name=layer_name,
 623                    input=input,
 624                    target=target,
 625                    batch_idx=batch_idx,
 626                    num_batches=num_batches,
 627                )
 628            elif self.importance_type == "cbp_adaptation":
 629                importance_step = self.get_importance_step_layer_weight_abs_sum(
 630                    layer_name=layer_name,
 631                    if_output_weight=False,
 632                    reciprocal=True,
 633                )
 634            elif self.importance_type == "cbp_adaptive_contribution":
 635                importance_step = (
 636                    self.get_importance_step_layer_cbp_adaptive_contribution(
 637                        layer_name=layer_name,
 638                        activation=activation,
 639                    )
 640                )
 641
 642            importance_step = min_max_normalize(
 643                importance_step
  644            )  # min-max scale the importance to $[0, 1]$. See Eq. (5) in the paper
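                 # min_max_normalize presumably maps the layer's importance vector to
                 # [0, 1] via (x - min(x)) / (max(x) - min(x)); see its definition in
                 # clarena.utils.transforms.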
 645
 646            # multiply the importance by the training mask. See Eq. (6) in the paper
 647            if self.step_multiply_training_mask:
 648                importance_step = importance_step * mask[layer_name]
 649
 650            # update accumulated importance
 651            self.importances[self.task_id][layer_name] = (
 652                self.importances[self.task_id][layer_name] + importance_step
 653            )
 654
 655        # update number of steps counter
 656        self.num_steps_t += 1
 657
 658    def on_train_end(self) -> None:
 659        r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task."""
 660        super().on_train_end()  # store the mask and update cumulative and summative masks
 661
 662        for layer_name in self.backbone.weighted_layer_names:
 663
 664            # average the neuron-wise step importance. See Eq. (4) in the paper
 665            self.importances[self.task_id][layer_name] = (
 666                self.importances[self.task_id][layer_name]
 667            ) / self.num_steps_t
 668
 669            # add the base importance. See Eq. (6) in the paper
 670            self.importances[self.task_id][layer_name] = (
 671                self.importances[self.task_id][layer_name] + self.base_importance
 672            )
 673
 674            # filter unmasked importance
 675            if self.filter_unmasked_importance:
 676                self.importances[self.task_id][layer_name] = (
 677                    self.importances[self.task_id][layer_name]
 678                    * self.backbone.masks[f"{self.task_id}"][layer_name]
 679                )
 680
 681            # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
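                 # An illustrative example for the *_decrease strategies handled below:
                 # with task_id = 3 and 'quadratic_decrease', the weights are
                 # w_1 = (3 - 1 + 1)^2 = 9, w_2 = 4, w_3 = 1, so earlier tasks
                 # contribute more heavily to the summative importance.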
 682            if self.importance_summing_strategy == "add_latest":
 683                self.summative_importance_for_previous_tasks[
 684                    layer_name
 685                ] += self.importances[self.task_id][layer_name]
 686
 687            elif self.importance_summing_strategy == "add_all":
 688                for t in range(1, self.task_id + 1):
 689                    self.summative_importance_for_previous_tasks[
 690                        layer_name
 691                    ] += self.importances[t][layer_name]
 692
 693            elif self.importance_summing_strategy == "add_average":
 694                for t in range(1, self.task_id + 1):
 695                    self.summative_importance_for_previous_tasks[layer_name] += (
 696                        self.importances[t][layer_name] / self.task_id
 697                    )
  698            else:
  699                # weighted decrease strategies: reset the summative importance to
  700                # zero, then add every previous task's importance with a weight w_t
  701                self.summative_importance_for_previous_tasks[layer_name] = (
  702                    torch.zeros_like(
  703                        self.summative_importance_for_previous_tasks[layer_name]
  704                    ).to(self.device)
  705                )  # start accumulating from 0
  706
  707                for t in range(1, self.task_id + 1):
  708                    if self.importance_summing_strategy == "linear_decrease":
  709                        s = self.importance_summing_strategy_linear_step
  710                        w_t = s * (self.task_id - t) + 1
  711                    elif self.importance_summing_strategy == "quadratic_decrease":
  712                        w_t = (self.task_id - t + 1) ** 2
  713                    elif self.importance_summing_strategy == "cubic_decrease":
  714                        w_t = (self.task_id - t + 1) ** 3
  715                    elif self.importance_summing_strategy == "exponential_decrease":
  716                        r = self.importance_summing_strategy_exponential_rate
  717                        w_t = r ** (self.task_id - t + 1)
  718                    elif self.importance_summing_strategy == "log_decrease":
  719                        a = self.importance_summing_strategy_log_base
  720                        # use (task_id - t + 1) so the log argument stays positive at t == task_id
  721                        w_t = math.log(self.task_id - t + 1, a) + 1
  722                    elif self.importance_summing_strategy == "factorial_decrease":
  723                        w_t = math.factorial(self.task_id - t + 1)
  724                    else:
  725                        raise ValueError(
  726                            f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}"
  727                        )
  728
  729                    # accumulate inside the per-task loop so that every previous
  730                    # task contributes with its weight w_t
  731                    self.summative_importance_for_previous_tasks[layer_name] += (
  732                        self.importances[t][layer_name] * w_t
  733                    )
 735
 736    def get_importance_step_layer_weight_abs_sum(
  737        self,
 738        layer_name: str,
 739        if_output_weight: bool,
 740        reciprocal: bool,
 741    ) -> Tensor:
 742        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.
 743
 744        **Args:**
 745        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 746        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 747        - **reciprocal** (`bool`): whether to take reciprocal.
 748
 749        **Returns:**
 750        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 751        """
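             # Shape note (illustrative): for a Linear weight of shape (out, in), summing
             # over every dimension except dim 0 gives one value per output unit; for a
             # Conv2d weight of shape (out, in, kH, kW) it likewise reduces to (out,).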
 752        layer = self.backbone.get_layer_by_name(layer_name)
 753
 754        if not if_output_weight:
 755            weight_abs = torch.abs(layer.weight.data)
 756            weight_abs_sum = torch.sum(
 757                weight_abs,
 758                dim=[
 759                    i for i in range(weight_abs.dim()) if i != 0
 760                ],  # sum over the input dimension
 761            )
 762        else:
 763            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
 764            weight_abs_sum = torch.sum(
 765                weight_abs,
 766                dim=[
 767                    i for i in range(weight_abs.dim()) if i != 1
 768                ],  # sum over the output dimension
 769            )
 770
 771        if reciprocal:
 772            weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum)
 773            importance_step_layer = weight_abs_sum_reciprocal
 774        else:
 775            importance_step_layer = weight_abs_sum
 776        importance_step_layer = importance_step_layer.detach()
 777
 778        return importance_step_layer
 779
 780    def get_importance_step_layer_weight_gradient_abs_sum(
  781        self,
 782        layer_name: str,
 783        if_output_weight: bool,
 784    ) -> Tensor:
 785        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.
 786
 787        **Args:**
 788        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 789        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 790
 791        **Returns:**
 792        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 793        """
 794        layer = self.backbone.get_layer_by_name(layer_name)
 795
 796        if not if_output_weight:
 797            gradient_abs = torch.abs(layer.weight.grad.data)
 798            gradient_abs_sum = torch.sum(
 799                gradient_abs,
 800                dim=[
 801                    i for i in range(gradient_abs.dim()) if i != 0
 802                ],  # sum over the input dimension
 803            )
 804        else:
 805            gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data)
 806            gradient_abs_sum = torch.sum(
 807                gradient_abs,
 808                dim=[
 809                    i for i in range(gradient_abs.dim()) if i != 1
 810                ],  # sum over the output dimension
 811            )
 812
 813        importance_step_layer = gradient_abs_sum
 814        importance_step_layer = importance_step_layer.detach()
 815
 816        return importance_step_layer
 817
 818    def get_importance_step_layer_activation_abs(
  819        self,
 820        activation: Tensor,
 821    ) -> Tensor:
 822        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum.
 823
 824        **Args:**
  825        - **activation** (`Tensor`): the activation tensor of the layer for the training batch, with the unit dimension at dim 1.
 826
 827        **Returns:**
 828        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 829        """
 830        activation_abs_batch_mean = torch.mean(
 831            torch.abs(activation),
 832            dim=[
 833                i for i in range(activation.dim()) if i != 1
 834            ],  # average the features over batch samples
 835        )
 836        importance_step_layer = activation_abs_batch_mean
 837        importance_step_layer = importance_step_layer.detach()
 838
 839        return importance_step_layer
 840
 841    def get_importance_step_layer_weight_abs_sum_x_activation_abs(
  842        self,
 843        layer_name: str,
 844        activation: Tensor,
 845        if_output_weight: bool,
 846    ) -> Tensor:
 847        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
 848
 849        **Args:**
 850        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
  851        - **activation** (`Tensor`): the activation tensor of the layer for the training batch, with the unit dimension at dim 1.
 852        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 853
 854        **Returns:**
 855        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 856        """
 857        layer = self.backbone.get_layer_by_name(layer_name)
 858
 859        if not if_output_weight:
 860            weight_abs = torch.abs(layer.weight.data)
 861            weight_abs_sum = torch.sum(
 862                weight_abs,
 863                dim=[
 864                    i for i in range(weight_abs.dim()) if i != 0
 865                ],  # sum over the input dimension
 866            )
 867        else:
 868            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
 869            weight_abs_sum = torch.sum(
 870                weight_abs,
 871                dim=[
 872                    i for i in range(weight_abs.dim()) if i != 1
 873                ],  # sum over the output dimension
 874            )
 875
 876        activation_abs_batch_mean = torch.mean(
 877            torch.abs(activation),
 878            dim=[
 879                i for i in range(activation.dim()) if i != 1
 880            ],  # average the features over batch samples
 881        )
 882
 883        importance_step_layer = weight_abs_sum * activation_abs_batch_mean
 884        importance_step_layer = importance_step_layer.detach()
 885
 886        return importance_step_layer
 887
 888    def get_importance_step_layer_gradient_x_activation_abs(
  889        self,
 890        layer_name: str,
 891        input: Tensor | tuple[Tensor, ...],
 892        target: Tensor | None,
 893        batch_idx: int,
 894        num_batches: int,
 895    ) -> Tensor:
 896        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum.
 897
 898        **Args:**
 899        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 900        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
 901        - **target** (`Tensor` | `None`): the target batch of the training step.
 902        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
 903        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
 904
 905        **Returns:**
 906        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 907        """
 908        layer = self.backbone.get_layer_by_name(layer_name)
 909
 910        input = input.requires_grad_()
 911
 912        # initialize the Layer Gradient X Activation object
 913        layer_gradient_x_activation = LayerGradientXActivation(
 914            forward_func=self.forward, layer=layer
 915        )
 916
 917        self.set_forward_func_return_logits_only(True)
 918        # calculate layer attribution of the step
 919        attribution = layer_gradient_x_activation.attribute(
 920            inputs=input,
 921            target=target,
 922            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
 923        )
 924        self.set_forward_func_return_logits_only(False)
 925
 926        attribution_abs_batch_mean = torch.mean(
 927            torch.abs(attribution),
 928            dim=[
 929                i for i in range(attribution.dim()) if i != 1
 930            ],  # average the features over batch samples
 931        )
 932
 933        importance_step_layer = attribution_abs_batch_mean
 934        importance_step_layer = importance_step_layer.detach()
 935
 936        return importance_step_layer
 937
 938    def get_importance_step_layer_weight_gradient_square_sum(
  939        self,
 940        layer_name: str,
 941        activation: Tensor,
 942        if_output_weight: bool,
 943    ) -> Tensor:
 944        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 945
 946        **Args:**
 947        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
  948        - **activation** (`Tensor`): the activation tensor of the layer for the training batch, with the unit dimension at dim 1.
 949        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 950
 951        **Returns:**
 952        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 953        """
 954        layer = self.backbone.get_layer_by_name(layer_name)
 955
 956        if not if_output_weight:
 957            gradient_square = layer.weight.grad.data**2
 958            gradient_square_sum = torch.sum(
 959                gradient_square,
 960                dim=[
 961                    i for i in range(gradient_square.dim()) if i != 0
 962                ],  # sum over the input dimension
 963            )
 964        else:
 965            gradient_square = self.next_layer(layer_name).weight.grad.data**2
 966            gradient_square_sum = torch.sum(
 967                gradient_square,
 968                dim=[
 969                    i for i in range(gradient_square.dim()) if i != 1
 970                ],  # sum over the output dimension
 971            )
 972
 973        importance_step_layer = gradient_square_sum
 974        importance_step_layer = importance_step_layer.detach()
 975
 976        return importance_step_layer
 977
 978    def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
  979        self,
 980        layer_name: str,
 981        activation: Tensor,
 982        if_output_weight: bool,
 983    ) -> Tensor:
 984        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of layer weight gradient squares multiplied by absolute values of activation. The weight gradient square is equal to fisher information in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 985
 986        **Args:**
 987        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
  988        - **activation** (`Tensor`): the activation tensor of the layer for the training batch, with the unit dimension at dim 1.
 989        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 990
 991        **Returns:**
 992        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 993        """
 994        layer = self.backbone.get_layer_by_name(layer_name)
 995
 996        if not if_output_weight:
 997            gradient_square = layer.weight.grad.data**2
 998            gradient_square_sum = torch.sum(
 999                gradient_square,
1000                dim=[
1001                    i for i in range(gradient_square.dim()) if i != 0
1002                ],  # sum over the input dimension
1003            )
1004        else:
1005            gradient_square = self.next_layer(layer_name).weight.grad.data**2
1006            gradient_square_sum = torch.sum(
1007                gradient_square,
1008                dim=[
1009                    i for i in range(gradient_square.dim()) if i != 1
1010                ],  # sum over the output dimension
1011            )
1012
1013        activation_abs_batch_mean = torch.mean(
1014            torch.abs(activation),
1015            dim=[
1016                i for i in range(activation.dim()) if i != 1
1017            ],  # average the features over batch samples
1018        )
1019
1020        importance_step_layer = gradient_square_sum * activation_abs_batch_mean
1021        importance_step_layer = importance_step_layer.detach()
1022
1023        return importance_step_layer
1024
1025    def get_importance_step_layer_conductance_abs(
1026        self,
1027        layer_name: str,
1028        input: Tensor | tuple[Tensor, ...],
1029        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1030        target: Tensor | None,
1031        batch_idx: int,
1032        num_batches: int,
1033    ) -> Tensor:
1034        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.
1035
1036        **Args:**
1037        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1038        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1039        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
1040        - **target** (`Tensor` | `None`): the target batch of the training step.
1041        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1042        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1043
1044        **Returns:**
1045        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1046        """
1047        layer = self.backbone.get_layer_by_name(layer_name)
1048
1049        # initialize the Layer Conductance object
1050        layer_conductance = LayerConductance(forward_func=self.forward, layer=layer)
1051
1052        self.set_forward_func_return_logits_only(True)
1053        # calculate layer attribution of the step
1054        attribution = layer_conductance.attribute(
1055            inputs=input,
1056            baselines=baselines,
1057            target=target,
1058            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1059        )
1060        self.set_forward_func_return_logits_only(False)
1061
1062        attribution_abs_batch_mean = torch.mean(
1063            torch.abs(attribution),
1064            dim=[
1065                i for i in range(attribution.dim()) if i != 1
1066            ],  # average the features over batch samples
1067        )
1068
1069        importance_step_layer = attribution_abs_batch_mean
1070        importance_step_layer = importance_step_layer.detach()
1071
1072        return importance_step_layer
1073
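# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the module source above): the Captum-based
# importance methods in this class all follow the same pattern, shown here on
# a small standalone model (the model, data, and targets are assumptions for
# illustration only) -- construct a Layer* attributor over the chosen layer,
# attribute a batch towards its targets, then average the absolute
# attributions over the batch to obtain one value per unit.
import torch
from captum.attr import LayerConductance

model = torch.nn.Sequential(
    torch.nn.Linear(4, 5), torch.nn.ReLU(), torch.nn.Linear(5, 3)
)
inputs = torch.randn(8, 4)
targets = torch.randint(0, 3, (8,))

layer_conductance = LayerConductance(forward_func=model, layer=model[0])
attribution = layer_conductance.attribute(inputs=inputs, target=targets)  # (8, 5)
importance = attribution.abs().mean(dim=0).detach()  # one value per unit, (5,)
# ----------------------------------------------------------------------------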
1074    def get_importance_step_layer_internal_influence_abs(
1075        self,
1076        layer_name: str,
1077        input: Tensor | tuple[Tensor, ...],
1078        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1079        target: Tensor | None,
1080        batch_idx: int,
1081        num_batches: int,
1082    ) -> Tensor:
1083        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum.
1084
1085        **Args:**
1086        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1087        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1088        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details.
1089        - **target** (`Tensor` | `None`): the target batch of the training step.
1090        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1091        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1092
1093        **Returns:**
1094        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1095        """
1096        layer = self.backbone.get_layer_by_name(layer_name)
1097
1098        # initialize the Internal Influence object
1099        internal_influence = InternalInfluence(forward_func=self.forward, layer=layer)
1100
1101        # convert the target to long type to avoid error
1102        target = target.long() if target is not None else None
1103
1104        self.set_forward_func_return_logits_only(True)
1105        # calculate layer attribution of the step
1106        attribution = internal_influence.attribute(
1107            inputs=input,
1108            baselines=baselines,
1109            target=target,
1110            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1111            n_steps=5,  # set 5 instead of the default 50 to accelerate the computation
1112        )
1113        self.set_forward_func_return_logits_only(False)
1114
1115        attribution_abs_batch_mean = torch.mean(
1116            torch.abs(attribution),
1117            dim=[
1118                i for i in range(attribution.dim()) if i != 1
1119            ],  # average the features over batch samples
1120        )
1121
1122        importance_step_layer = attribution_abs_batch_mean
1123        importance_step_layer = importance_step_layer.detach()
1124
1125        return importance_step_layer
1126
1127    def get_importance_step_layer_gradcam_abs(
1128        self,
1129        layer_name: str,
1130        input: Tensor | tuple[Tensor, ...],
1131        target: Tensor | None,
1132        batch_idx: int,
1133        num_batches: int,
1134    ) -> Tensor:
1135        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum.
1136
1137        **Args:**
1138        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1139        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1140        - **target** (`Tensor` | `None`): the target batch of the training step.
1141        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1142        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1143
1144        **Returns:**
1145        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1146        """
1147        layer = self.backbone.get_layer_by_name(layer_name)
1148
1149        # initialize the GradCAM object
1150        gradcam = LayerGradCam(forward_func=self.forward, layer=layer)
1151
1152        self.set_forward_func_return_logits_only(True)
1153        # calculate layer attribution of the step
1154        attribution = gradcam.attribute(
1155            inputs=input,
1156            target=target,
1157            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1158        )
1159        self.set_forward_func_return_logits_only(False)
1160
1161        attribution_abs_batch_mean = torch.mean(
1162            torch.abs(attribution),
1163            dim=[
1164                i for i in range(attribution.dim()) if i != 1
1165            ],  # average the features over batch samples
1166        )
1167
1168        importance_step_layer = attribution_abs_batch_mean
1169        importance_step_layer = importance_step_layer.detach()
1170
1171        return importance_step_layer
1172
1173    def get_importance_step_layer_deeplift_abs(
1174        self,
1175        layer_name: str,
1176        input: Tensor | tuple[Tensor, ...],
1177        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1178        target: Tensor | None,
1179        batch_idx: int,
1180        num_batches: int,
1181    ) -> Tensor:
1182        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum.
1183
1184        **Args:**
1185        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1186        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1187        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details.
1188        - **target** (`Tensor` | `None`): the target batch of the training step.
1189        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1190        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1191
1192        **Returns:**
1193        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1194        """
1195        layer = self.backbone.get_layer_by_name(layer_name)
1196
1197        # initialize the Layer DeepLift object
1198        layer_deeplift = LayerDeepLift(model=self, layer=layer)
1199
1200        # convert the target to long type to avoid error
1201        target = target.long() if target is not None else None
1202
1203        self.set_forward_func_return_logits_only(True)
1204        # calculate layer attribution of the step
1205        attribution = layer_deeplift.attribute(
1206            inputs=input,
1207            baselines=baselines,
1208            target=target,
1209            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1210        )
1211        self.set_forward_func_return_logits_only(False)
1212
1213        attribution_abs_batch_mean = torch.mean(
1214            torch.abs(attribution),
1215            dim=[
1216                i for i in range(attribution.dim()) if i != 1
1217            ],  # average the features over batch samples
1218        )
1219
1220        importance_step_layer = attribution_abs_batch_mean
1221        importance_step_layer = importance_step_layer.detach()
1222
1223        return importance_step_layer
1224
1225    def get_importance_step_layer_deepliftshap_abs(
1226        self,
1227        layer_name: str,
1228        input: Tensor | tuple[Tensor, ...],
1229        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1230        target: Tensor | None,
1231        batch_idx: int,
1232        num_batches: int,
1233    ) -> Tensor:
1234        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.
1235
1236        **Args:**
1237        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1238        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1239        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
1240        - **target** (`Tensor` | `None`): the target batch of the training step.
1241        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1242        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1243
1244        **Returns:**
1245        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1246        """
1247        layer = self.backbone.get_layer_by_name(layer_name)
1248
1249        # initialize the Layer DeepLiftShap object
1250        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)
1251
1252        # convert the target to long type to avoid error
1253        target = target.long() if target is not None else None
1254
1255        self.set_forward_func_return_logits_only(True)
1256        # calculate layer attribution of the step
1257        attribution = layer_deepliftshap.attribute(
1258            inputs=input,
1259            baselines=baselines,
1260            target=target,
1261            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1262        )
1263        self.set_forward_func_return_logits_only(False)
1264
1265        attribution_abs_batch_mean = torch.mean(
1266            torch.abs(attribution),
1267            dim=[
1268                i for i in range(attribution.dim()) if i != 1
1269            ],  # average the features over batch samples
1270        )
1271
1272        importance_step_layer = attribution_abs_batch_mean
1273        importance_step_layer = importance_step_layer.detach()
1274
1275        return importance_step_layer
1276
1277    def get_importance_step_layer_gradientshap_abs(
1278        self,
1279        layer_name: str,
1280        input: Tensor | tuple[Tensor, ...],
1281        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1282        target: Tensor | None,
1283        batch_idx: int,
1284        num_batches: int,
1285    ) -> Tensor:
1286        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.
1287
1288        **Args:**
1289        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1290        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1291        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
1292        - **target** (`Tensor` | `None`): the target batch of the training step.
1293        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1294        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1295
1296        **Returns:**
1297        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1298        """
1299        layer = self.backbone.get_layer_by_name(layer_name)
1300
1301        if baselines is None:
1302            baselines = torch.zeros_like(
1303                input
1304            )  # baselines are mandatory for GradientShap API. We explicitly set them to zero
1305
1306        # initialize the Layer GradientShap object
1307        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)
1308
1309        # convert the target to long type to avoid error
1310        target = target.long() if target is not None else None
1311
1312        self.set_forward_func_return_logits_only(True)
1313        # calculate layer attribution of the step
1314        attribution = layer_gradientshap.attribute(
1315            inputs=input,
1316            baselines=baselines,
1317            target=target,
1318            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1319        )
1320        self.set_forward_func_return_logits_only(False)
1321
1322        attribution_abs_batch_mean = torch.mean(
1323            torch.abs(attribution),
1324            dim=[
1325                i for i in range(attribution.dim()) if i != 1
1326            ],  # average the features over batch samples
1327        )
1328
1329        importance_step_layer = attribution_abs_batch_mean
1330        importance_step_layer = importance_step_layer.detach()
1331
1332        return importance_step_layer
1333
1334    def get_importance_step_layer_integrated_gradients_abs(
1335        self,
1336        layer_name: str,
1337        input: Tensor | tuple[Tensor, ...],
1338        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1339        target: Tensor | None,
1340        batch_idx: int,
1341        num_batches: int,
1342    ) -> Tensor:
1343        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.
1344
1345        **Args:**
1346        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1347        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1348        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
1349        - **target** (`Tensor` | `None`): the target batch of the training step.
1350        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1351        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1352
1353        **Returns:**
1354        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1355        """
1356        layer = self.backbone.get_layer_by_name(layer_name)
1357
1358        # initialize the Layer Integrated Gradients object
1359        layer_integrated_gradients = LayerIntegratedGradients(
1360            forward_func=self.forward, layer=layer
1361        )
1362
1363        self.set_forward_func_return_logits_only(True)
1364        # calculate layer attribution of the step
1365        attribution = layer_integrated_gradients.attribute(
1366            inputs=input,
1367            baselines=baselines,
1368            target=target,
1369            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1370        )
1371        self.set_forward_func_return_logits_only(False)
1372
1373        attribution_abs_batch_mean = torch.mean(
1374            torch.abs(attribution),
1375            dim=[
1376                i for i in range(attribution.dim()) if i != 1
1377            ],  # average the features over batch samples
1378        )
1379
1380        importance_step_layer = attribution_abs_batch_mean
1381        importance_step_layer = importance_step_layer.detach()
1382
1383        return importance_step_layer
1384
1385    def get_importance_step_layer_feature_ablation_abs(
1386        self,
1387        layer_name: str,
1388        input: Tensor | tuple[Tensor, ...],
1389        layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1390        target: Tensor | None,
1391        batch_idx: int,
1392        num_batches: int,
1393        if_captum: bool = False,
1394    ) -> Tensor:
1395        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum.
1396
1397        **Args:**
1398        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1399        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1400        - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details.
1401        - **target** (`Tensor` | `None`): the target batch of the training step.
1402        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1403        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1404        - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster.
1405
1406        **Returns:**
1407        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1408        """
1409        layer = self.backbone.get_layer_by_name(layer_name)
1410
1411        if not if_captum:
1412            # 1. Baseline logits (take first element of forward output)
1413            baseline_out, _, _ = self.forward(
1414                input, "train", batch_idx, num_batches, self.task_id
1415            )
1416            if target is not None:
1417                baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1)
1418            else:
1419                baseline_scores = baseline_out.sum(dim=1)
1420
1421            # 2. Capture layer’s output shape
1422            activs = {}
1423            handle = layer.register_forward_hook(
1424                lambda module, inp, out: activs.setdefault("output", out.detach())
1425            )
1426            _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id)
1427            handle.remove()
1428            layer_output = activs["output"]  # shape (B, F, ...)
1429
1430            # 3. Build baseline tensor matching that shape
1431            if layer_baselines is None:
1432                baseline_tensor = torch.zeros_like(layer_output)
1433            elif isinstance(layer_baselines, (int, float)):
1434                baseline_tensor = torch.full_like(layer_output, layer_baselines)
1435            elif isinstance(layer_baselines, Tensor):
1436                if layer_baselines.shape == layer_output.shape:
1437                    baseline_tensor = layer_baselines
1438                elif layer_baselines.shape == layer_output.shape[1:]:
1439                    baseline_tensor = layer_baselines.unsqueeze(0).repeat(
1440                        layer_output.size(0), *([1] * layer_baselines.ndim)
1441                    )
1442                else:
1443                    raise ValueError(f"layer_baselines shape {tuple(layer_baselines.shape)} does not match the layer output shape {tuple(layer_output.shape)}")
1444            else:
1445                raise ValueError(f"Unsupported type for layer_baselines: {type(layer_baselines)}")
1446
1447            B, F = layer_output.size(0), layer_output.size(1)
1448
1449            # 4. Create a “mega-batch” replicating the input F times
1450            if isinstance(input, tuple):
1451                mega_inputs = tuple(
1452                    t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:])
1453                    for t in input
1454                )
1455            else:
1456                mega_inputs = (
1457                    input.unsqueeze(0)
1458                    .repeat(F, *([1] * input.ndim))
1459                    .view(-1, *input.shape[1:])
1460                )
1461
1462            # 5. Equally replicate the baseline tensor
1463            mega_baseline = (
1464                baseline_tensor.unsqueeze(0)
1465                .repeat(F, *([1] * baseline_tensor.ndim))
1466                .view(-1, *baseline_tensor.shape[1:])
1467            )
1468
1469            # 6. Precompute vectorized indices
1470            device = layer_output.device
1471            positions = torch.arange(F * B, device=device)  # [0,1,...,F*B-1]
1472            feat_idx = torch.arange(F, device=device).repeat_interleave(
1473                B
1474            )  # [0,0,...,1,1,...,F-1]
1475
1476            # 7. One hook to ablate (replace with its baseline) each channel slice across the mega-batch
1477            def mega_ablate_hook(module, inp, out):
1478                out_mod = out.clone()
1479                # for each sample in the mega-batch, replace its corresponding channel with the baseline
1480                out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx]
1481                return out_mod
1482
1483            h = layer.register_forward_hook(mega_ablate_hook)
1484            out_all, _, _ = self.forward(
1485                mega_inputs, "train", batch_idx, num_batches, self.task_id
1486            )
1487            h.remove()
1488
1489            # 8. Recover scores, reshape [F*B] → [F, B], diff & mean
1490            if target is not None:
1491                tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1)
1492                scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1)
1493            else:
1494                scores_all = out_all.sum(dim=1)
1495
1496            scores_all = scores_all.view(F, B)
1497            diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all)
1498            importance_step_layer = diffs.mean(dim=1).detach()  # [F]
1499
1500            return importance_step_layer
1501
1502        else:
1503            # initialize the Layer Feature Ablation object
1504            layer_feature_ablation = LayerFeatureAblation(
1505                forward_func=self.forward, layer=layer
1506            )
1507
1508            # calculate layer attribution of the step
1509            self.set_forward_func_return_logits_only(True)
1510            attribution = layer_feature_ablation.attribute(
1511                inputs=input,
1512                layer_baselines=layer_baselines,
1513                # target=target, # disable target to enable perturbations_per_eval
1514                additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1515                perturbations_per_eval=128,  # to accelerate the computation
1516            )
1517            self.set_forward_func_return_logits_only(False)
1518
1519            attribution_abs_batch_mean = torch.mean(
1520                torch.abs(attribution),
1521                dim=[
1522                    i for i in range(attribution.dim()) if i != 1
1523                ],  # average the features over batch samples
1524            )
1525
1526        importance_step_layer = attribution_abs_batch_mean
1527        importance_step_layer = importance_step_layer.detach()
1528
1529        return importance_step_layer
1530
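# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the module source above): the idea behind
# the vectorized (non-Captum) feature-ablation branch above, on a toy model
# with zero baselines and the sum of logits as the score (both are assumptions
# for illustration). The batch is replicated F times, one unit is ablated per
# replica via a forward hook, and the mean absolute score change per unit
# gives the importance.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2)
)
layer = model[0]
x = torch.randn(5, 4)                       # B = 5 samples
B, F = x.size(0), layer.out_features        # F = 3 units to ablate

baseline_scores = model(x).sum(dim=1)       # shape (B,)

mega_x = x.repeat(F, 1)                     # shape (F * B, 4)
positions = torch.arange(F * B)
feat_idx = torch.arange(F).repeat_interleave(B)  # unit ablated in each row

def ablate_hook(module, inp, out):
    out = out.clone()
    out[positions, feat_idx] = 0.0          # ablate one unit per replica
    return out

handle = layer.register_forward_hook(ablate_hook)
scores = model(mega_x).sum(dim=1).view(F, B)     # shape (F, B)
handle.remove()

importance = (baseline_scores.unsqueeze(0) - scores).abs().mean(dim=1)  # (F,)
# ----------------------------------------------------------------------------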
1531    def get_importance_step_layer_lrp_abs(
1532        self,
1533        layer_name: str,
1534        input: Tensor | tuple[Tensor, ...],
1535        target: Tensor | None,
1536        batch_idx: int,
1537        num_batches: int,
1538    ) -> Tensor:
1539        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum.
1540
1541        **Args:**
1542        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1543        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1544        - **target** (`Tensor` | `None`): the target batch of the training step.
1545        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1546        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1547
1548        **Returns:**
1549        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1550        """
1551        layer = self.backbone.get_layer_by_name(layer_name)
1552
1553        # initialize the Layer LRP object
1554        layer_lrp = LayerLRP(model=self, layer=layer)
1555
1556        # switch the model to evaluation mode so that dropout is disabled and batch norm running statistics are not updated during attribution
1557        self.eval()
1558
1559        self.set_forward_func_return_logits_only(True)
1560        # calculate layer attribution of the step
1561        attribution = layer_lrp.attribute(
1562            inputs=input,
1563            target=target,
1564            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1565        )
1566        self.set_forward_func_return_logits_only(False)
1567
1568        attribution_abs_batch_mean = torch.mean(
1569            torch.abs(attribution),
1570            dim=[
1571                i for i in range(attribution.dim()) if i != 1
1572            ],  # average the features over batch samples
1573        )
1574
1575        importance_step_layer = attribution_abs_batch_mean
1576        importance_step_layer = importance_step_layer.detach()
1577
1578        return importance_step_layer
1579
1580    def get_importance_step_layer_cbp_adaptive_contribution(
1581        self,
1582        layer_name: str,
1583        activation: Tensor,
1584    ) -> Tensor:
1585        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the reciprocal of sum of absolute values of layer input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
1586
1587        **Args:**
1588        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1589        - **activation** (`Tensor`): the activation tensor of the layer for the training batch. Its dim 1 is the unit dimension, with size equal to the number of units of the layer.
1590
1591        **Returns:**
1592        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1593        """
1594        layer = self.backbone.get_layer_by_name(layer_name)
1595
1596        input_weight_abs = torch.abs(layer.weight.data)
1597        input_weight_abs_sum = torch.sum(
1598            input_weight_abs,
1599            dim=[
1600                i for i in range(input_weight_abs.dim()) if i != 0
1601            ],  # sum over the input dimension
1602        )
1603        input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum)
1604
1605        output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
1606        output_weight_abs_sum = torch.sum(
1607            output_weight_abs,
1608            dim=[
1609                i for i in range(output_weight_abs.dim()) if i != 1
1610            ],  # sum over the output dimension
1611        )
1612
1613        activation_abs_batch_mean = torch.mean(
1614            torch.abs(activation),
1615            dim=[
1616                i for i in range(activation.dim()) if i != 1
1617            ],  # average the features over batch samples
1618        )
1619
1620        importance_step_layer = (
1621            output_weight_abs_sum
1622            * activation_abs_batch_mean
1623            * input_weight_abs_sum_reciprocal
1624        )
1625        importance_step_layer = importance_step_layer.detach()
1626
1627        return importance_step_layer
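# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the module source above): the CBP adaptive
# contribution utility computed by the method above, on a toy pair of linear
# layers with random data (assumptions for illustration). The utility of a
# unit is its outgoing absolute weight mass times its mean absolute
# activation, divided by its incoming absolute weight mass.
import torch

fc1 = torch.nn.Linear(4, 3)   # layer whose units are scored
fc2 = torch.nn.Linear(3, 2)   # its successor ("next layer")
h = torch.relu(fc1(torch.randn(8, 4)))             # activations, shape (8, 3)

in_weight_abs_sum = fc1.weight.abs().sum(dim=1)    # per unit, shape (3,)
out_weight_abs_sum = fc2.weight.abs().sum(dim=0)   # per unit, shape (3,)
act_abs_mean = h.abs().mean(dim=0)                 # per unit, shape (3,)

utility = (out_weight_abs_sum * act_abs_mean / in_weight_abs_sum).detach()  # (3,)
# ----------------------------------------------------------------------------
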
class FGAdaHAT(clarena.cl_algorithms.adahat.AdaHAT):
  37class FGAdaHAT(AdaHAT):
  38    r"""FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm.
  39
  40    An architecture-based continual learning approach that improves [AdaHAT (Adaptive Hard Attention to the Task)](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.
  41
  42    We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components.
  43    """
  44
  45    def __init__(
  46        self,
  47        backbone: HATMaskBackbone,
  48        heads: HeadsTIL,
  49        adjustment_intensity: float,
  50        importance_type: str,
  51        importance_summing_strategy: str,
  52        importance_scheduler_type: str,
  53        neuron_to_weight_importance_aggregation_mode: str,
  54        s_max: float,
  55        clamp_threshold: float,
  56        mask_sparsity_reg_factor: float,
  57        mask_sparsity_reg_mode: str = "original",
  58        base_importance: float = 0.01,
  59        base_mask_sparsity_reg: float = 0.1,
  60        base_linear: float = 10,
  61        filter_by_cumulative_mask: bool = False,
  62        filter_unmasked_importance: bool = True,
  63        step_multiply_training_mask: bool = True,
  64        task_embedding_init_mode: str = "N01",
  65        importance_summing_strategy_linear_step: float | None = None,
  66        importance_summing_strategy_exponential_rate: float | None = None,
  67        importance_summing_strategy_log_base: float | None = None,
  68        non_algorithmic_hparams: dict[str, Any] = {},
  69    ) -> None:
  70        r"""Initialize the FG-AdaHAT algorithm with the network.
  71
  72        **Args:**
  73        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
  74        - **heads** (`HeadsTIL`): output heads. FG-AdaHAT supports only TIL (Task-Incremental Learning).
  75        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
  76        - **importance_type** (`str`): the type of neuron-wise importance, must be one of:
  77            1. 'input_weight_abs_sum': sum of absolute input weights;
  78            2. 'output_weight_abs_sum': sum of absolute output weights;
  79            3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
  80            4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
  81            5. 'activation_abs': absolute activation;
  82            6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
  83            7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
  84            8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
  85            9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
  86            10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
  87            11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
  88            12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
  89            13. 'conductance_abs': absolute layer conductance;
  90            14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
  91            15. 'gradcam_abs': absolute Grad-CAM;
  92            16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
  93            17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
  94            18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
  95            19. 'integrated_gradients_abs': absolute Integrated Gradients;
  96            20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
  97            21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
  98            22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
  99            23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
 100        - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of:
 101            1. 'add_latest': add the latest neuron-wise importance to the summative importance;
 102            2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
 103            3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
 104            4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
 105            5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID;
 106            6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID;
 107            7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
 108            8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
 109            9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID;
 110        - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ multiplied to parameter importance. Must be one of:
 111            1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`);
 112            2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
 113            3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
 114        - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
 115            1. 'min': take the minimum of neuron-wise importance for each weight;
 116            2. 'max': take the maximum of neuron-wise importance for each weight;
 117            3. 'mean': take the mean of neuron-wise importance for each weight.
 118        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 119        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 120        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
 121        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
 122            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
 123            2. 'cross': the cross version of mask sparsity regularization.
 124        - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01.
 125        - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
 126        - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
 127        - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False.
 128        - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set them to 0) at the end of task training. Default: True.
 129        - **step_multiply_training_mask** (`bool`): whether to multiply the training mask to the importance at each training step. Default: True.
 130        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
 131            1. 'N01' (default): standard normal distribution $N(0, 1)$.
 132            2. 'U-11': uniform distribution $U(-1, 1)$.
 133            3. 'U01': uniform distribution $U(0, 1)$.
 134            4. 'U-10': uniform distribution $U(-1, 0)$.
 135            5. 'last': inherit the task embedding from the last task.
 136        - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0.
 137        - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1.
 138        - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1.
 139        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 140
 141        """
 142        super().__init__(
 143            backbone=backbone,
 144            heads=heads,
 145            adjustment_mode=None,  # use FG-AdaHAT's own adjustment mechanism instead of AdaHAT's
 146            adjustment_intensity=adjustment_intensity,
 147            s_max=s_max,
 148            clamp_threshold=clamp_threshold,
 149            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
 150            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
 151            task_embedding_init_mode=task_embedding_init_mode,
 152            epsilon=base_mask_sparsity_reg,  # the epsilon is now the base mask sparsity regularization factor
 153            non_algorithmic_hparams=non_algorithmic_hparams,
 154        )
 155
 156        # save additional algorithmic hyperparameters
 157        self.save_hyperparameters(
 158            "adjustment_intensity",
 159            "importance_type",
 160            "importance_summing_strategy",
 161            "importance_scheduler_type",
 162            "neuron_to_weight_importance_aggregation_mode",
 163            "s_max",
 164            "clamp_threshold",
 165            "mask_sparsity_reg_factor",
 166            "mask_sparsity_reg_mode",
 167            "base_importance",
 168            "base_mask_sparsity_reg",
 169            "base_linear",
 170            "filter_by_cumulative_mask",
 171            "filter_unmasked_importance",
 172            "step_multiply_training_mask",
 173        )
 174
 175        self.importance_type: str | None = importance_type
 176        r"""The type of the neuron-wise importance added to AdaHAT importance."""
 177
 178        self.importance_scheduler_type: str = importance_scheduler_type
 179        r"""The type of the importance scheduler."""
 180        self.neuron_to_weight_importance_aggregation_mode: str = (
 181            neuron_to_weight_importance_aggregation_mode
 182        )
 183        r"""The mode of aggregation from neuron-wise to weight-wise importance. """
 184        self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask
 185        r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate."""
 186        self.filter_unmasked_importance: bool = filter_unmasked_importance
 187        r"""The flag to filter unmasked importance values (set them to 0) at the end of task training."""
 188        self.step_multiply_training_mask: bool = step_multiply_training_mask
 189        r"""The flag to multiply the training mask to the importance at each training step."""
 190
 191        # importance summing strategy
 192        self.importance_summing_strategy: str = importance_summing_strategy
 193        r"""The strategy to sum the neuron-wise importance for previous tasks."""
 194        if importance_summing_strategy_linear_step is not None:
 195            self.importance_summing_strategy_linear_step: float = (
 196                importance_summing_strategy_linear_step
 197            )
 198            r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease')."""
 199        if importance_summing_strategy_exponential_rate is not None:
 200            self.importance_summing_strategy_exponential_rate: float = (
 201                importance_summing_strategy_exponential_rate
 202            )
 203            r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """
 204        if importance_summing_strategy_log_base is not None:
 205            self.importance_summing_strategy_log_base: float = (
 206                importance_summing_strategy_log_base
 207            )
 208            r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """
 209
 210        # base values
 211        self.base_importance: float = base_importance
 212        r"""The base value added to the importance to avoid zero. """
 213        self.base_mask_sparsity_reg: float = base_mask_sparsity_reg
 214        r"""The base value added to the mask sparsity regularization to avoid zero. """
 215        self.base_linear: float = base_linear
 216        r"""The base value added to the linear layer to avoid zero. """
 217
 218        self.importances: dict[int, dict[str, Tensor]] = {}
 219        r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units. It is $I^{\tau}_{l}$ in the paper. Keys are task IDs and values are the corresponding importance tensors. Each importance tensor is a dict where keys are layer names and values are the importance tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ). """
 220        self.summative_importance_for_previous_tasks: dict[str, Tensor] = {}
 221        r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """
 222
 223        self.num_steps_t: int
 224        r"""The number of training steps for the current task `self.task_id`."""
 225        # set manual optimization
 226        self.automatic_optimization = False
 227
 228        FGAdaHAT.sanity_check(self)
 229
 230    def sanity_check(self) -> None:
 231        r"""Sanity check."""
 232
 233        # check importance type
 234        if self.importance_type not in [
 235            "input_weight_abs_sum",
 236            "output_weight_abs_sum",
 237            "input_weight_gradient_abs_sum",
 238            "output_weight_gradient_abs_sum",
 239            "activation_abs",
 240            "input_weight_abs_sum_x_activation_abs",
 241            "output_weight_abs_sum_x_activation_abs",
 242            "gradient_x_activation_abs",
 243            "input_weight_gradient_square_sum",
 244            "output_weight_gradient_square_sum",
 245            "input_weight_gradient_square_sum_x_activation_abs",
 246            "output_weight_gradient_square_sum_x_activation_abs",
 247            "conductance_abs",
 248            "internal_influence_abs",
 249            "gradcam_abs",
 250            "deeplift_abs",
 251            "deepliftshap_abs",
 252            "gradientshap_abs",
 253            "integrated_gradients_abs",
 254            "feature_ablation_abs",
 255            "lrp_abs",
 256            "cbp_adaptation",
 257            "cbp_adaptive_contribution",
 258        ]:
 259            raise ValueError(
 260                f"importance_type must be one of the predefined types, but got {self.importance_type}"
 261            )
 262
 263        # check importance summing strategy
 264        if self.importance_summing_strategy not in [
 265            "add_latest",
 266            "add_all",
 267            "add_average",
 268            "linear_decrease",
 269            "quadratic_decrease",
 270            "cubic_decrease",
 271            "exponential_decrease",
 272            "log_decrease",
 273            "factorial_decrease",
 274        ]:
 275            raise ValueError(
 276                f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}"
 277            )
 278
 279        # check importance scheduler type
 280        if self.importance_scheduler_type not in [
 281            "linear_sparsity_reg",
 282            "sparsity_reg",
 283            "summative_mask_sparsity_reg",
 284        ]:
 285            raise ValueError(
 286                f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}"
 287            )
 288
 289        # check neuron to weight importance aggregation mode
 290        if self.neuron_to_weight_importance_aggregation_mode not in [
 291            "min",
 292            "max",
 293            "mean",
 294        ]:
 295            raise ValueError(
 296                f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}"
 297            )
 298
 299        # check base values
 300        if self.base_importance < 0:
 301            raise ValueError(
 302                f"base_importance must be >= 0, but got {self.base_importance}"
 303            )
 304        if self.base_mask_sparsity_reg <= 0:
 305            raise ValueError(
 306                f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}"
 307            )
 308        if self.base_linear <= 0:
 309            raise ValueError(f"base_linear must be > 0, but got {self.base_linear}")
 310
 311    def on_train_start(self) -> None:
 312        r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization."""
 313        super().on_train_start()
 314
 315        self.importances[self.task_id] = (
 316            {}
 317        )  # initialize the importance for the current task
 318
 319        # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time.
 320        for layer_name in self.backbone.weighted_layer_names:
 321            layer = self.backbone.get_layer_by_name(
 322                layer_name
 323            )  # get the layer by its name
 324            num_units = layer.weight.shape[0]
 325
 326            # initialize the accumulated importance at the beginning of each task
 327            self.importances[self.task_id][layer_name] = torch.zeros(num_units).to(
 328                self.device
 329            )
 330
 331            # reset the number of steps counter for the current task
 332            self.num_steps_t = 0
 333
 334            # initialize the summative neuron-wise importance at the beginning of the first task
 335            if self.task_id == 1:
 336                self.summative_importance_for_previous_tasks[layer_name] = torch.zeros(
 337                    num_units
 338                ).to(
 339                    self.device
 340                )  # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$
 341
 342    def clip_grad_by_adjustment(
 343        self,
 344        network_sparsity: dict[str, Tensor],
 345    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
 346        r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper.
 347
 348        Note that the task embeddings cover every layer in the backbone network, so no parameters are left out of this mechanism. This applies not only to the parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.
 349
 350        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 351
 352        **Args:**
 353        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.
 354
 355        **Returns:**
 356        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
 357        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
 358        - **capacity** (`Tensor`): the calculated network capacity.
 359        """
 360
 361        # initialize network capacity metric
 362        capacity = HATNetworkCapacityMetric().to(self.device)
 363        adjustment_rate_weight = {}
 364        adjustment_rate_bias = {}
 365
 366        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper
 367        for layer_name in self.backbone.weighted_layer_names:
 368
 369            layer = self.backbone.get_layer_by_name(
 370                layer_name
 371            )  # get the layer by its name
 372
 373            # placeholder for the adjustment rate to avoid the error of using it before assignment
 374            adjustment_rate_weight_layer = 1
 375            adjustment_rate_bias_layer = 1
 376
 377            # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, added the base value, and filtered by the mask
 378            weight_importance, bias_importance = (
 379                self.backbone.get_layer_measure_parameter_wise(
 380                    neuron_wise_measure=self.summative_importance_for_previous_tasks,
 381                    layer_name=layer_name,
 382                    aggregation_mode=self.neuron_to_weight_importance_aggregation_mode,
 383                )
 384            )
 385
 386            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
 387                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
 388                layer_name=layer_name,
 389                aggregation_mode="min",
 390            )
 391
 392            # filter the weight importance by the cumulative mask
 393            if self.filter_by_cumulative_mask:
 394                weight_importance = weight_importance * weight_mask
 395                bias_importance = bias_importance * bias_mask
 396
 397            network_sparsity_layer = network_sparsity[layer_name]
 398
 399            # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper
 400            factor = network_sparsity_layer + self.base_mask_sparsity_reg
 401            if self.importance_scheduler_type == "linear_sparsity_reg":
 402                factor = factor * (self.task_id + self.base_linear)
 403            elif self.importance_scheduler_type == "sparsity_reg":
 404                pass
 405            elif self.importance_scheduler_type == "summative_mask_sparsity_reg":
 406                factor = factor * (
 407                    self.summative_mask_for_previous_tasks[layer_name] + self.base_linear
 408                )
 409
 410            # calculate the adjustment rate
 411            adjustment_rate_weight_layer = torch.div(
 412                self.adjustment_intensity,
 413                (factor * weight_importance + self.adjustment_intensity),
 414            )
 415
 416            adjustment_rate_bias_layer = torch.div(
 417                self.adjustment_intensity,
 418                (factor * bias_importance + self.adjustment_intensity),
 419            )
 420
 421            # apply the adjustment rate to the gradients
 422            layer.weight.grad.data *= adjustment_rate_weight_layer
 423            if layer.bias is not None:
 424                layer.bias.grad.data *= adjustment_rate_bias_layer
 425
 426            # store the adjustment rate for logging
 427            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
 428            if layer.bias is not None:
 429                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
 430
 431            # update network capacity metric
 432            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
 433
 434        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()
 435
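        # Illustrative note (not part of the algorithm): with adjustment intensity
        # $\alpha$, scheduler factor $f$ and parameter-wise importance $I$, the rate
        # computed above is $\alpha / (f \cdot I + \alpha)$. A toy numeric sketch,
        # assuming adjustment_intensity = 0.1 and factor = 1.0:
        #
        #     importance 0.0  ->  rate 0.1 / (0.0 + 0.1) = 1.0    (gradient kept as is)
        #     importance 1.0  ->  rate 0.1 / (1.0 + 0.1) ~= 0.09  (gradient heavily damped)
        #
        # i.e. parameters important to previous tasks receive smaller updates.
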
 436    def on_train_batch_end(
 437        self, outputs: dict[str, Any], batch: Any, batch_idx: int
 438    ) -> None:
 439        r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.
 440
 441        **Args:**
 442        - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`).
 443        - **batch** (`Any`): training data batch.
 444        - **batch_idx** (`int`): index of the current batch (for mask figure file name).
 445        """
 446
 447        # get potential useful information from training batch
 448        activations = outputs["activations"]
 449        input = outputs["input"]
 450        target = outputs["target"]
 451        mask = outputs["mask"]
 452        num_batches = self.trainer.num_training_batches
 453
 454        for layer_name in self.backbone.weighted_layer_names:
 455            # layer-wise operation
 456
 457            activation = activations[layer_name]
 458
 459            # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper.
 460            if self.importance_type == "input_weight_abs_sum":
 461                importance_step = self.get_importance_step_layer_weight_abs_sum(
 462                    layer_name=layer_name,
 463                    if_output_weight=False,
 464                    reciprocal=False,
 465                )
 466            elif self.importance_type == "output_weight_abs_sum":
 467                importance_step = self.get_importance_step_layer_weight_abs_sum(
 468                    layer_name=layer_name,
 469                    if_output_weight=True,
 470                    reciprocal=False,
 471                )
 472            elif self.importance_type == "input_weight_gradient_abs_sum":
 473                importance_step = (
 474                    self.get_importance_step_layer_weight_gradient_abs_sum(
 475                        layer_name=layer_name, if_output_weight=False
 476                    )
 477                )
 478            elif self.importance_type == "output_weight_gradient_abs_sum":
 479                importance_step = (
 480                    self.get_importance_step_layer_weight_gradient_abs_sum(
 481                        layer_name=layer_name, if_output_weight=True
 482                    )
 483                )
 484            elif self.importance_type == "activation_abs":
 485                importance_step = self.get_importance_step_layer_activation_abs(
 486                    activation=activation
 487                )
 488            elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
 489                importance_step = (
 490                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
 491                        layer_name=layer_name,
 492                        activation=activation,
 493                        if_output_weight=False,
 494                    )
 495                )
 496            elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
 497                importance_step = (
 498                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
 499                        layer_name=layer_name,
 500                        activation=activation,
 501                        if_output_weight=True,
 502                    )
 503                )
 504            elif self.importance_type == "gradient_x_activation_abs":
 505                importance_step = (
 506                    self.get_importance_step_layer_gradient_x_activation_abs(
 507                        layer_name=layer_name,
 508                        input=input,
 509                        target=target,
 510                        batch_idx=batch_idx,
 511                        num_batches=num_batches,
 512                    )
 513                )
 514            elif self.importance_type == "input_weight_gradient_square_sum":
 515                importance_step = (
 516                    self.get_importance_step_layer_weight_gradient_square_sum(
 517                        layer_name=layer_name,
 518                        activation=activation,
 519                        if_output_weight=False,
 520                    )
 521                )
 522            elif self.importance_type == "output_weight_gradient_square_sum":
 523                importance_step = (
 524                    self.get_importance_step_layer_weight_gradient_square_sum(
 525                        layer_name=layer_name,
 526                        activation=activation,
 527                        if_output_weight=True,
 528                    )
 529                )
 530            elif (
 531                self.importance_type
 532                == "input_weight_gradient_square_sum_x_activation_abs"
 533            ):
 534                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 535                    layer_name=layer_name,
 536                    activation=activation,
 537                    if_output_weight=False,
 538                )
 539            elif (
 540                self.importance_type
 541                == "output_weight_gradient_square_sum_x_activation_abs"
 542            ):
 543                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 544                    layer_name=layer_name,
 545                    activation=activation,
 546                    if_output_weight=True,
 547                )
 548            elif self.importance_type == "conductance_abs":
 549                importance_step = self.get_importance_step_layer_conductance_abs(
 550                    layer_name=layer_name,
 551                    input=input,
 552                    baselines=None,
 553                    target=target,
 554                    batch_idx=batch_idx,
 555                    num_batches=num_batches,
 556                )
 557            elif self.importance_type == "internal_influence_abs":
 558                importance_step = self.get_importance_step_layer_internal_influence_abs(
 559                    layer_name=layer_name,
 560                    input=input,
 561                    baselines=None,
 562                    target=target,
 563                    batch_idx=batch_idx,
 564                    num_batches=num_batches,
 565                )
 566            elif self.importance_type == "gradcam_abs":
 567                importance_step = self.get_importance_step_layer_gradcam_abs(
 568                    layer_name=layer_name,
 569                    input=input,
 570                    target=target,
 571                    batch_idx=batch_idx,
 572                    num_batches=num_batches,
 573                )
 574            elif self.importance_type == "deeplift_abs":
 575                importance_step = self.get_importance_step_layer_deeplift_abs(
 576                    layer_name=layer_name,
 577                    input=input,
 578                    baselines=None,
 579                    target=target,
 580                    batch_idx=batch_idx,
 581                    num_batches=num_batches,
 582                )
 583            elif self.importance_type == "deepliftshap_abs":
 584                importance_step = self.get_importance_step_layer_deepliftshap_abs(
 585                    layer_name=layer_name,
 586                    input=input,
 587                    baselines=None,
 588                    target=target,
 589                    batch_idx=batch_idx,
 590                    num_batches=num_batches,
 591                )
 592            elif self.importance_type == "gradientshap_abs":
 593                importance_step = self.get_importance_step_layer_gradientshap_abs(
 594                    layer_name=layer_name,
 595                    input=input,
 596                    baselines=None,
 597                    target=target,
 598                    batch_idx=batch_idx,
 599                    num_batches=num_batches,
 600                )
 601            elif self.importance_type == "integrated_gradients_abs":
 602                importance_step = (
 603                    self.get_importance_step_layer_integrated_gradients_abs(
 604                        layer_name=layer_name,
 605                        input=input,
 606                        baselines=None,
 607                        target=target,
 608                        batch_idx=batch_idx,
 609                        num_batches=num_batches,
 610                    )
 611                )
 612            elif self.importance_type == "feature_ablation_abs":
 613                importance_step = self.get_importance_step_layer_feature_ablation_abs(
 614                    layer_name=layer_name,
 615                    input=input,
 616                    layer_baselines=None,
 617                    target=target,
 618                    batch_idx=batch_idx,
 619                    num_batches=num_batches,
 620                )
 621            elif self.importance_type == "lrp_abs":
 622                importance_step = self.get_importance_step_layer_lrp_abs(
 623                    layer_name=layer_name,
 624                    input=input,
 625                    target=target,
 626                    batch_idx=batch_idx,
 627                    num_batches=num_batches,
 628                )
 629            elif self.importance_type == "cbp_adaptation":
 630                importance_step = self.get_importance_step_layer_weight_abs_sum(
 631                    layer_name=layer_name,
 632                    if_output_weight=False,
 633                    reciprocal=True,
 634                )
 635            elif self.importance_type == "cbp_adaptive_contribution":
 636                importance_step = (
 637                    self.get_importance_step_layer_cbp_adaptive_contribution(
 638                        layer_name=layer_name,
 639                        activation=activation,
 640                    )
 641                )
 642
 643            importance_step = min_max_normalize(
 644                importance_step
 645            )  # min-max scaling the utility to $[0, 1]$. See Eq. (5) in the paper
 646
 647            # multiply the importance by the training mask. See Eq. (6) in the paper
 648            if self.step_multiply_training_mask:
 649                importance_step = importance_step * mask[layer_name]
 650
 651            # update accumulated importance
 652            self.importances[self.task_id][layer_name] = (
 653                self.importances[self.task_id][layer_name] + importance_step
 654            )
 655
 656        # update number of steps counter
 657        self.num_steps_t += 1
 658
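        # A minimal sketch (illustration only, assuming `min_max_normalize` performs
        # standard per-tensor min-max scaling) of the post-processing applied to each
        # step importance above:
        #
        #     raw = torch.tensor([0.2, 1.0, 0.6])
        #     scaled = min_max_normalize(raw)                  # tensor([0.0, 1.0, 0.5]), Eq. (5)
        #     masked = scaled * torch.tensor([1.0, 0.0, 1.0])  # tensor([0.0, 0.0, 0.5]), Eq. (6)
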
 659    def on_train_end(self) -> None:
 660        r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task."""
 661        super().on_train_end()  # store the mask and update cumulative and summative masks
 662
 663        for layer_name in self.backbone.weighted_layer_names:
 664
 665            # average the neuron-wise step importance. See Eq. (4) in the paper
 666            self.importances[self.task_id][layer_name] = (
 667                self.importances[self.task_id][layer_name] / self.num_steps_t
 668            )
 669
 670            # add the base importance. See Eq. (6) in the paper
 671            self.importances[self.task_id][layer_name] = (
 672                self.importances[self.task_id][layer_name] + self.base_importance
 673            )
 674
 675            # filter unmasked importance
 676            if self.filter_unmasked_importance:
 677                self.importances[self.task_id][layer_name] = (
 678                    self.importances[self.task_id][layer_name]
 679                    * self.backbone.masks[f"{self.task_id}"][layer_name]
 680                )
 681
 682            # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
 683            if self.importance_summing_strategy == "add_latest":
 684                self.summative_importance_for_previous_tasks[
 685                    layer_name
 686                ] += self.importances[self.task_id][layer_name]
 687
 688            elif self.importance_summing_strategy == "add_all":
 689                for t in range(1, self.task_id + 1):
 690                    self.summative_importance_for_previous_tasks[
 691                        layer_name
 692                    ] += self.importances[t][layer_name]
 693
 694            elif self.importance_summing_strategy == "add_average":
 695                for t in range(1, self.task_id + 1):
 696                    self.summative_importance_for_previous_tasks[layer_name] += (
 697                        self.importances[t][layer_name] / self.task_id
 698                    )
 699            else:
 700                # decreasing strategies: reset to zero, then re-accumulate the
 701                # importance of every task with a decreasing weight $w_t$
 702                self.summative_importance_for_previous_tasks[layer_name] = (
 703                    torch.zeros_like(
 704                        self.summative_importance_for_previous_tasks[layer_name]
 705                    ).to(self.device)
 706                )
 707
 708                for t in range(1, self.task_id + 1):
 709                    if self.importance_summing_strategy == "linear_decrease":
 710                        s = self.importance_summing_strategy_linear_step
 711                        w_t = s * (self.task_id - t) + 1
 712                    elif self.importance_summing_strategy == "quadratic_decrease":
 713                        w_t = (self.task_id - t + 1) ** 2
 714                    elif self.importance_summing_strategy == "cubic_decrease":
 715                        w_t = (self.task_id - t + 1) ** 3
 716                    elif self.importance_summing_strategy == "exponential_decrease":
 717                        r = self.importance_summing_strategy_exponential_rate
 718                        w_t = r ** (self.task_id - t + 1)
 719                    elif self.importance_summing_strategy == "log_decrease":
 720                        a = self.importance_summing_strategy_log_base
 721                        w_t = (
 722                            math.log(self.task_id - t + 1, a) + 1
 723                        )  # shifted by 1 to avoid log(0) when t is the current task
 724                    elif self.importance_summing_strategy == "factorial_decrease":
 725                        w_t = math.factorial(self.task_id - t + 1)
 726                    else:
 727                        raise ValueError(
 728                            "Unknown importance summing strategy: "
 729                            f"{self.importance_summing_strategy}"
 730                        )
 731
 732                    # accumulate inside the loop so that every task contributes
 733                    self.summative_importance_for_previous_tasks[layer_name] += (
 734                        self.importances[t][layer_name] * w_t
 735                    )
 736
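        # Illustration (not part of the algorithm): the decreasing summing strategies
        # above give larger weights to earlier tasks. For example, with task_id = 3
        # and the "quadratic_decrease" strategy, the weights for tasks t = 1, 2, 3 are
        # w_t = (3 - t + 1) ** 2 = 9, 4, 1, so the oldest task contributes the most.
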
 737    def get_importance_step_layer_weight_abs_sum(
 738        self,
 739        layer_name: str,
 740        if_output_weight: bool,
 741        reciprocal: bool,
 742    ) -> Tensor:
 743        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.
 744
 745        **Args:**
 746        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 747        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 748        - **reciprocal** (`bool`): whether to take reciprocal.
 749
 750        **Returns:**
 751        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 752        """
 753        layer = self.backbone.get_layer_by_name(layer_name)
 754
 755        if not if_output_weight:
 756            weight_abs = torch.abs(layer.weight.data)
 757            weight_abs_sum = torch.sum(
 758                weight_abs,
 759                dim=[
 760                    i for i in range(weight_abs.dim()) if i != 0
 761                ],  # sum over the input dimension
 762            )
 763        else:
 764            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
 765            weight_abs_sum = torch.sum(
 766                weight_abs,
 767                dim=[
 768                    i for i in range(weight_abs.dim()) if i != 1
 769                ],  # sum over the output dimension
 770            )
 771
 772        if reciprocal:
 773            weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum)
 774            importance_step_layer = weight_abs_sum_reciprocal
 775        else:
 776            importance_step_layer = weight_abs_sum
 777        importance_step_layer = importance_step_layer.detach()
 778
 779        return importance_step_layer
 780
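        # Shape note (illustrative): for a Linear layer with weight of shape
        # (out_features, in_features), summing |W| over every dim except 0 gives a
        # per-unit vector of shape (out_features,); for a Conv2d weight of shape
        # (C_out, C_in, kH, kW) the same reduction gives shape (C_out,). The output
        # weight variant instead reduces the next layer's weight over every dim
        # except 1, again producing one value per unit of the current layer.
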
 781    def get_importance_step_layer_weight_gradient_abs_sum(
 782        self,
 783        layer_name: str,
 784        if_output_weight: bool,
 785    ) -> Tensor:
 786        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.
 787
 788        **Args:**
 789        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 790        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 791
 792        **Returns:**
 793        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 794        """
 795        layer = self.backbone.get_layer_by_name(layer_name)
 796
 797        if not if_output_weight:
 798            gradient_abs = torch.abs(layer.weight.grad.data)
 799            gradient_abs_sum = torch.sum(
 800                gradient_abs,
 801                dim=[
 802                    i for i in range(gradient_abs.dim()) if i != 0
 803                ],  # sum over the input dimension
 804            )
 805        else:
 806            gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data)
 807            gradient_abs_sum = torch.sum(
 808                gradient_abs,
 809                dim=[
 810                    i for i in range(gradient_abs.dim()) if i != 1
 811                ],  # sum over the output dimension
 812            )
 813
 814        importance_step_layer = gradient_abs_sum
 815        importance_step_layer = importance_step_layer.detach()
 816
 817        return importance_step_layer
 818
 819    def get_importance_step_layer_activation_abs(
 820        self,
 821        activation: Tensor,
 822    ) -> Tensor:
 823        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum.
 824
 825        **Args:**
 826        - **activation** (`Tensor`): the activation tensor of the layer for the current training batch (unit dimension at dim 1).
 827
 828        **Returns:**
 829        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 830        """
 831        activation_abs_batch_mean = torch.mean(
 832            torch.abs(activation),
 833            dim=[
 834                i for i in range(activation.dim()) if i != 1
 835            ],  # average the features over batch samples
 836        )
 837        importance_step_layer = activation_abs_batch_mean
 838        importance_step_layer = importance_step_layer.detach()
 839
 840        return importance_step_layer
 841
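        # Shape note (illustrative): the batch mean above keeps only dim 1, so an
        # activation of shape (batch, features) reduces to (features,) and a
        # convolutional activation of shape (batch, channels, H, W) reduces to
        # (channels,), i.e. one importance value per unit of the layer.
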
 842    def get_importance_step_layer_weight_abs_sum_x_activation_abs(
 843        self,
 844        layer_name: str,
 845        activation: Tensor,
 846        if_output_weight: bool,
 847    ) -> Tensor:
 848        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
 849
 850        **Args:**
 851        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 852        - **activation** (`Tensor`): the activation tensor of the layer for the current training batch (unit dimension at dim 1).
 853        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 854
 855        **Returns:**
 856        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 857        """
 858        layer = self.backbone.get_layer_by_name(layer_name)
 859
 860        if not if_output_weight:
 861            weight_abs = torch.abs(layer.weight.data)
 862            weight_abs_sum = torch.sum(
 863                weight_abs,
 864                dim=[
 865                    i for i in range(weight_abs.dim()) if i != 0
 866                ],  # sum over the input dimension
 867            )
 868        else:
 869            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
 870            weight_abs_sum = torch.sum(
 871                weight_abs,
 872                dim=[
 873                    i for i in range(weight_abs.dim()) if i != 1
 874                ],  # sum over the output dimension
 875            )
 876
 877        activation_abs_batch_mean = torch.mean(
 878            torch.abs(activation),
 879            dim=[
 880                i for i in range(activation.dim()) if i != 1
 881            ],  # average the features over batch samples
 882        )
 883
 884        importance_step_layer = weight_abs_sum * activation_abs_batch_mean
 885        importance_step_layer = importance_step_layer.detach()
 886
 887        return importance_step_layer
 888
 889    def get_importance_step_layer_gradient_x_activation_abs(
 890        self,
 891        layer_name: str,
 892        input: Tensor | tuple[Tensor, ...],
 893        target: Tensor | None,
 894        batch_idx: int,
 895        num_batches: int,
 896    ) -> Tensor:
 897        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum.
 898
 899        **Args:**
 900        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 901        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
 902        - **target** (`Tensor` | `None`): the target batch of the training step.
 903        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
 904        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
 905
 906        **Returns:**
 907        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 908        """
 909        layer = self.backbone.get_layer_by_name(layer_name)
 910
 911        input = input.requires_grad_()
 912
 913        # initialize the Layer Gradient X Activation object
 914        layer_gradient_x_activation = LayerGradientXActivation(
 915            forward_func=self.forward, layer=layer
 916        )
 917
 918        self.set_forward_func_return_logits_only(True)
 919        # calculate layer attribution of the step
 920        attribution = layer_gradient_x_activation.attribute(
 921            inputs=input,
 922            target=target,
 923            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
 924        )
 925        self.set_forward_func_return_logits_only(False)
 926
 927        attribution_abs_batch_mean = torch.mean(
 928            torch.abs(attribution),
 929            dim=[
 930                i for i in range(attribution.dim()) if i != 1
 931            ],  # average the features over batch samples
 932        )
 933
 934        importance_step_layer = attribution_abs_batch_mean
 935        importance_step_layer = importance_step_layer.detach()
 936
 937        return importance_step_layer
 938
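        # Usage pattern (sketch): the Captum-based importance methods in this class
        # (conductance, internal influence, Grad-CAM, DeepLift, GradientSHAP,
        # integrated gradients, feature ablation, LRP) all follow the same three
        # steps as the method above, roughly
        #
        #     attr_method = SomeCaptumLayerAttribution(forward_func=self.forward, layer=layer)
        #     self.set_forward_func_return_logits_only(True)  # Captum expects logits only
        #     attribution = attr_method.attribute(
        #         inputs=input,
        #         target=target,
        #         additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        #     )
        #     self.set_forward_func_return_logits_only(False)
        #
        # after which |attribution| is averaged over the batch to get one value per
        # unit. `SomeCaptumLayerAttribution` is a placeholder for the Captum class used.
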
 939    def get_importance_step_layer_weight_gradient_square_sum(
 940        self,
 941        layer_name: str,
 942        activation: Tensor,
 943        if_output_weight: bool,
 944    ) -> Tensor:
 945        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of squared gradients of the layer weights. The squared weight gradient corresponds to the diagonal Fisher information used in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 946
 947        **Args:**
 948        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 949        - **activation** (`Tensor`): the activation tensor of the layer for the current training batch (unit dimension at dim 1).
 950        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 951
 952        **Returns:**
 953        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 954        """
 955        layer = self.backbone.get_layer_by_name(layer_name)
 956
 957        if not if_output_weight:
 958            gradient_square = layer.weight.grad.data**2
 959            gradient_square_sum = torch.sum(
 960                gradient_square,
 961                dim=[
 962                    i for i in range(gradient_square.dim()) if i != 0
 963                ],  # sum over the input dimension
 964            )
 965        else:
 966            gradient_square = self.next_layer(layer_name).weight.grad.data**2
 967            gradient_square_sum = torch.sum(
 968                gradient_square,
 969                dim=[
 970                    i for i in range(gradient_square.dim()) if i != 1
 971                ],  # sum over the output dimension
 972            )
 973
 974        importance_step_layer = gradient_square_sum
 975        importance_step_layer = importance_step_layer.detach()
 976
 977        return importance_step_layer
 978
 979    def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 980        self,
 981        layer_name: str,
 982        activation: Tensor,
 983        if_output_weight: bool,
 984    ) -> Tensor:
 985        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of squared weight gradients multiplied by the absolute values of the activation. The squared weight gradient corresponds to the diagonal Fisher information used in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 986
 987        **Args:**
 988        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 989        - **activation** (`Tensor`): the activation tensor of the layer for the current training batch (unit dimension at dim 1).
 990        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 991
 992        **Returns:**
 993        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 994        """
 995        layer = self.backbone.get_layer_by_name(layer_name)
 996
 997        if not if_output_weight:
 998            gradient_square = layer.weight.grad.data**2
 999            gradient_square_sum = torch.sum(
1000                gradient_square,
1001                dim=[
1002                    i for i in range(gradient_square.dim()) if i != 0
1003                ],  # sum over the input dimension
1004            )
1005        else:
1006            gradient_square = self.next_layer(layer_name).weight.grad.data**2
1007            gradient_square_sum = torch.sum(
1008                gradient_square,
1009                dim=[
1010                    i for i in range(gradient_square.dim()) if i != 1
1011                ],  # sum over the output dimension
1012            )
1013
1014        activation_abs_batch_mean = torch.mean(
1015            torch.abs(activation),
1016            dim=[
1017                i for i in range(activation.dim()) if i != 1
1018            ],  # average the features over batch samples
1019        )
1020
1021        importance_step_layer = gradient_square_sum * activation_abs_batch_mean
1022        importance_step_layer = importance_step_layer.detach()
1023
1024        return importance_step_layer
1025
1026    def get_importance_step_layer_conductance_abs(
1027        self,
1028        layer_name: str,
1029        input: Tensor | tuple[Tensor, ...],
1030        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1031        target: Tensor | None,
1032        batch_idx: int,
1033        num_batches: int,
1034    ) -> Tensor:
1035        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.
1036
1037        **Args:**
1038        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1039        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1040        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
1041        - **target** (`Tensor` | `None`): the target batch of the training step.
1042        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1043        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1044
1045        **Returns:**
1046        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1047        """
1048        layer = self.backbone.get_layer_by_name(layer_name)
1049
1050        # initialize the Layer Conductance object
1051        layer_conductance = LayerConductance(forward_func=self.forward, layer=layer)
1052
1053        self.set_forward_func_return_logits_only(True)
1054        # calculate layer attribution of the step
1055        attribution = layer_conductance.attribute(
1056            inputs=input,
1057            baselines=baselines,
1058            target=target,
1059            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1060        )
1061        self.set_forward_func_return_logits_only(False)
1062
1063        attribution_abs_batch_mean = torch.mean(
1064            torch.abs(attribution),
1065            dim=[
1066                i for i in range(attribution.dim()) if i != 1
1067            ],  # average the features over batch samples
1068        )
1069
1070        importance_step_layer = attribution_abs_batch_mean
1071        importance_step_layer = importance_step_layer.detach()
1072
1073        return importance_step_layer
1074
1075    def get_importance_step_layer_internal_influence_abs(
1076        self,
1077        layer_name: str,
1078        input: Tensor | tuple[Tensor, ...],
1079        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1080        target: Tensor | None,
1081        batch_idx: int,
1082        num_batches: int,
1083    ) -> Tensor:
1084        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum.
1085
1086        **Args:**
1087        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1088        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1089        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details.
1090        - **target** (`Tensor` | `None`): the target batch of the training step.
1091        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1092        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1093
1094        **Returns:**
1095        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1096        """
1097        layer = self.backbone.get_layer_by_name(layer_name)
1098
1099        # initialize the Internal Influence object
1100        internal_influence = InternalInfluence(forward_func=self.forward, layer=layer)
1101
1102        # convert the target to long type to avoid error
1103        target = target.long() if target is not None else None
1104
1105        self.set_forward_func_return_logits_only(True)
1106        # calculate layer attribution of the step
1107        attribution = internal_influence.attribute(
1108            inputs=input,
1109            baselines=baselines,
1110            target=target,
1111            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1112            n_steps=5,  # set 5 instead of the default 50 to accelerate the computation
1113        )
1114        self.set_forward_func_return_logits_only(False)
1115
1116        attribution_abs_batch_mean = torch.mean(
1117            torch.abs(attribution),
1118            dim=[
1119                i for i in range(attribution.dim()) if i != 1
1120            ],  # average the features over batch samples
1121        )
1122
1123        importance_step_layer = attribution_abs_batch_mean
1124        importance_step_layer = importance_step_layer.detach()
1125
1126        return importance_step_layer
1127
1128    def get_importance_step_layer_gradcam_abs(
1129        self,
1130        layer_name: str,
1131        input: Tensor | tuple[Tensor, ...],
1132        target: Tensor | None,
1133        batch_idx: int,
1134        num_batches: int,
1135    ) -> Tensor:
1136        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://arxiv.org/abs/1610.02391). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum.
1137
1138        **Args:**
1139        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1140        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1141        - **target** (`Tensor` | `None`): the target batch of the training step.
1142        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1143        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1144
1145        **Returns:**
1146        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1147        """
1148        layer = self.backbone.get_layer_by_name(layer_name)
1149
1150        # initialize the GradCAM object
1151        gradcam = LayerGradCam(forward_func=self.forward, layer=layer)
1152
1153        self.set_forward_func_return_logits_only(True)
1154        # calculate layer attribution of the step
1155        attribution = gradcam.attribute(
1156            inputs=input,
1157            target=target,
1158            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1159        )
1160        self.set_forward_func_return_logits_only(False)
1161
1162        attribution_abs_batch_mean = torch.mean(
1163            torch.abs(attribution),
1164            dim=[
1165                i for i in range(attribution.dim()) if i != 1
1166            ],  # average the features over batch samples
1167        )
1168
1169        importance_step_layer = attribution_abs_batch_mean
1170        importance_step_layer = importance_step_layer.detach()
1171
1172        return importance_step_layer
1173
1174    def get_importance_step_layer_deeplift_abs(
1175        self,
1176        layer_name: str,
1177        input: Tensor | tuple[Tensor, ...],
1178        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1179        target: Tensor | None,
1180        batch_idx: int,
1181        num_batches: int,
1182    ) -> Tensor:
1183        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum.
1184
1185        **Args:**
1186        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1187        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1188        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details.
1189        - **target** (`Tensor` | `None`): the target batch of the training step.
1190        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1191        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1192
1193        **Returns:**
1194        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1195        """
1196        layer = self.backbone.get_layer_by_name(layer_name)
1197
1198        # initialize the Layer DeepLift object
1199        layer_deeplift = LayerDeepLift(model=self, layer=layer)
1200
1201        # convert the target to long type to avoid error
1202        target = target.long() if target is not None else None
1203
1204        self.set_forward_func_return_logits_only(True)
1205        # calculate layer attribution of the step
1206        attribution = layer_deeplift.attribute(
1207            inputs=input,
1208            baselines=baselines,
1209            target=target,
1210            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1211        )
1212        self.set_forward_func_return_logits_only(False)
1213
1214        attribution_abs_batch_mean = torch.mean(
1215            torch.abs(attribution),
1216            dim=[
1217                i for i in range(attribution.dim()) if i != 1
1218            ],  # average the features over batch samples
1219        )
1220
1221        importance_step_layer = attribution_abs_batch_mean
1222        importance_step_layer = importance_step_layer.detach()
1223
1224        return importance_step_layer
1225
1226    def get_importance_step_layer_deepliftshap_abs(
1227        self,
1228        layer_name: str,
1229        input: Tensor | tuple[Tensor, ...],
1230        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1231        target: Tensor | None,
1232        batch_idx: int,
1233        num_batches: int,
1234    ) -> Tensor:
1235        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.
1236
1237        **Args:**
1238        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1239        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1240        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
1241        - **target** (`Tensor` | `None`): the target batch of the training step.
1242        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1243        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1244
1245        **Returns:**
1246        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1247        """
1248        layer = self.backbone.get_layer_by_name(layer_name)
1249
1250        # initialize the Layer DeepLiftShap object
1251        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)
1252
1253        # convert the target to long type to avoid error
1254        target = target.long() if target is not None else None
1255
1256        self.set_forward_func_return_logits_only(True)
1257        # calculate layer attribution of the step
1258        attribution = layer_deepliftshap.attribute(
1259            inputs=input,
1260            baselines=baselines,
1261            target=target,
1262            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1263        )
1264        self.set_forward_func_return_logits_only(False)
1265
1266        attribution_abs_batch_mean = torch.mean(
1267            torch.abs(attribution),
1268            dim=[
1269                i for i in range(attribution.dim()) if i != 1
1270            ],  # average the features over batch samples
1271        )
1272
1273        importance_step_layer = attribution_abs_batch_mean
1274        importance_step_layer = importance_step_layer.detach()
1275
1276        return importance_step_layer
1277
1278    def get_importance_step_layer_gradientshap_abs(
1279        self,
1280        layer_name: str,
1281        input: Tensor | tuple[Tensor, ...],
1282        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1283        target: Tensor | None,
1284        batch_idx: int,
1285        num_batches: int,
1286    ) -> Tensor:
1287        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.
1288
1289        **Args:**
1290        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1291        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1292        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
1293        - **target** (`Tensor` | `None`): the target batch of the training step.
1294        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1295        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1296
1297        **Returns:**
1298        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1299        """
1300        layer = self.backbone.get_layer_by_name(layer_name)
1301
1302        if baselines is None:
1303            baselines = torch.zeros_like(
1304                input
1305            )  # baselines are mandatory for GradientShap API. We explicitly set them to zero
1306
1307        # initialize the Layer GradientShap object
1308        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)
1309
1310        # convert the target to long type to avoid error
1311        target = target.long() if target is not None else None
1312
1313        self.set_forward_func_return_logits_only(True)
1314        # calculate layer attribution of the step
1315        attribution = layer_gradientshap.attribute(
1316            inputs=input,
1317            baselines=baselines,
1318            target=target,
1319            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1320        )
1321        self.set_forward_func_return_logits_only(False)
1322
1323        attribution_abs_batch_mean = torch.mean(
1324            torch.abs(attribution),
1325            dim=[
1326                i for i in range(attribution.dim()) if i != 1
1327            ],  # average the features over batch samples
1328        )
1329
1330        importance_step_layer = attribution_abs_batch_mean
1331        importance_step_layer = importance_step_layer.detach()
1332
1333        return importance_step_layer
1334
1335    def get_importance_step_layer_integrated_gradients_abs(
1336        self,
1337        layer_name: str,
1338        input: Tensor | tuple[Tensor, ...],
1339        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1340        target: Tensor | None,
1341        batch_idx: int,
1342        num_batches: int,
1343    ) -> Tensor:
1344        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.
1345
1346        **Args:**
1347        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1348        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1349        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
1350        - **target** (`Tensor` | `None`): the target batch of the training step.
1351        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1352        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1353
1354        **Returns:**
1355        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1356        """
1357        layer = self.backbone.get_layer_by_name(layer_name)
1358
1359        # initialize the Layer Integrated Gradients object
1360        layer_integrated_gradients = LayerIntegratedGradients(
1361            forward_func=self.forward, layer=layer
1362        )
1363
1364        self.set_forward_func_return_logits_only(True)
1365        # calculate layer attribution of the step
1366        attribution = layer_integrated_gradients.attribute(
1367            inputs=input,
1368            baselines=baselines,
1369            target=target,
1370            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1371        )
1372        self.set_forward_func_return_logits_only(False)
1373
1374        attribution_abs_batch_mean = torch.mean(
1375            torch.abs(attribution),
1376            dim=[
1377                i for i in range(attribution.dim()) if i != 1
1378            ],  # average the features over batch samples
1379        )
1380
1381        importance_step_layer = attribution_abs_batch_mean
1382        importance_step_layer = importance_step_layer.detach()
1383
1384        return importance_step_layer
1385
1386    def get_importance_step_layer_feature_ablation_abs(
1387        self,
1388        layer_name: str,
1389        input: Tensor | tuple[Tensor, ...],
1390        layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1391        target: Tensor | None,
1392        batch_idx: int,
1393        num_batches: int,
1394        if_captum: bool = False,
1395    ) -> Tensor:
1396        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum.
1397
1398        **Args:**
1399        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1400        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1401        - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details.
1402        - **target** (`Tensor` | `None`): the target batch of the training step.
1403        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1404        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1405        - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster.
1406
1407        **Returns:**
1408        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1409        """
1410        layer = self.backbone.get_layer_by_name(layer_name)
1411
1412        if not if_captum:
1413            # 1. Baseline logits (take first element of forward output)
1414            baseline_out, _, _ = self.forward(
1415                input, "train", batch_idx, num_batches, self.task_id
1416            )
1417            if target is not None:
1418                baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1)
1419            else:
1420                baseline_scores = baseline_out.sum(dim=1)
1421
1422            # 2. Capture layer’s output shape
1423            activs = {}
1424            handle = layer.register_forward_hook(
1425                lambda module, inp, out: activs.setdefault("output", out.detach())
1426            )
1427            _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id)
1428            handle.remove()
1429            layer_output = activs["output"]  # shape (B, F, ...)
1430
1431            # 3. Build baseline tensor matching that shape
1432            if layer_baselines is None:
1433                baseline_tensor = torch.zeros_like(layer_output)
1434            elif isinstance(layer_baselines, (int, float)):
1435                baseline_tensor = torch.full_like(layer_output, layer_baselines)
1436            elif isinstance(layer_baselines, Tensor):
1437                if layer_baselines.shape == layer_output.shape:
1438                    baseline_tensor = layer_baselines
1439                elif layer_baselines.shape == layer_output.shape[1:]:
1440                    baseline_tensor = layer_baselines.unsqueeze(0).repeat(
1441                        layer_output.size(0), *([1] * layer_baselines.ndim)
1442                    )
1443                else:
1444                    raise ValueError("layer_baselines shape does not match the layer output shape")
1445            else:
1446                raise ValueError(f"unsupported type for layer_baselines: {type(layer_baselines)}")
1447
1448            B, F = layer_output.size(0), layer_output.size(1)
1449
1450            # 4. Create a “mega-batch” replicating the input F times
1451            if isinstance(input, tuple):
1452                mega_inputs = tuple(
1453                    t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:])
1454                    for t in input
1455                )
1456            else:
1457                mega_inputs = (
1458                    input.unsqueeze(0)
1459                    .repeat(F, *([1] * input.ndim))
1460                    .view(-1, *input.shape[1:])
1461                )
1462
1463            # 5. Equally replicate the baseline tensor
1464            mega_baseline = (
1465                baseline_tensor.unsqueeze(0)
1466                .repeat(F, *([1] * baseline_tensor.ndim))
1467                .view(-1, *baseline_tensor.shape[1:])
1468            )
1469
1470            # 6. Precompute vectorized indices
1471            device = layer_output.device
1472            positions = torch.arange(F * B, device=device)  # [0,1,...,F*B-1]
1473            feat_idx = torch.arange(F, device=device).repeat_interleave(
1474                B
1475            )  # [0,0,...,1,1,...,F-1]
1476
1477            # 7. One hook to zero out each channel slice across the mega-batch
1478            def mega_ablate_hook(module, inp, out):
1479                out_mod = out.clone()
1480                # for each sample in mega-batch, zero its corresponding channel
1481                out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx]
1482                return out_mod
1483
1484            h = layer.register_forward_hook(mega_ablate_hook)
1485            out_all, _, _ = self.forward(
1486                mega_inputs, "train", batch_idx, num_batches, self.task_id
1487            )
1488            h.remove()
1489
1490            # 8. Recover scores, reshape [F*B] → [F, B], diff & mean
1491            if target is not None:
1492                tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1)
1493                scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1)
1494            else:
1495                scores_all = out_all.sum(dim=1)
1496
1497            scores_all = scores_all.view(F, B)
1498            diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all)
1499            importance_step_layer = diffs.mean(dim=1).detach()  # [F]
1500
1501            return importance_step_layer
1502
1503        else:
1504            # initialize the Layer Feature Ablation object
1505            layer_feature_ablation = LayerFeatureAblation(
1506                forward_func=self.forward, layer=layer
1507            )
1508
1509            # calculate layer attribution of the step
1510            self.set_forward_func_return_logits_only(True)
1511            attribution = layer_feature_ablation.attribute(
1512                inputs=input,
1513                layer_baselines=layer_baselines,
1514                # target=target, # disable target to enable perturbations_per_eval
1515                additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1516                perturbations_per_eval=128,  # to accelerate the computation
1517            )
1518            self.set_forward_func_return_logits_only(False)
1519
1520            attribution_abs_batch_mean = torch.mean(
1521                torch.abs(attribution),
1522                dim=[
1523                    i for i in range(attribution.dim()) if i != 1
1524                ],  # average the features over batch samples
1525            )
1526
1527        importance_step_layer = attribution_abs_batch_mean
1528        importance_step_layer = importance_step_layer.detach()
1529
1530        return importance_step_layer
1531
1532    def get_importance_step_layer_lrp_abs(
1533        self,
1534        layer_name: str,
1535        input: Tensor | tuple[Tensor, ...],
1536        target: Tensor | None,
1537        batch_idx: int,
1538        num_batches: int,
1539    ) -> Tensor:
1540        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum.
1541
1542        **Args:**
1543        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1544        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1545        - **target** (`Tensor` | `None`): the target batch of the training step.
1546        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1547        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1548
1549        **Returns:**
1550        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1551        """
1552        layer = self.backbone.get_layer_by_name(layer_name)
1553
1554        # initialize the Layer LRP object
1555        layer_lrp = LayerLRP(model=self, layer=layer)
1556
1557        # set the model to evaluation mode while computing LRP attributions
1558        self.eval()
1559
1560        self.set_forward_func_return_logits_only(True)
1561        # calculate layer attribution of the step
1562        attribution = layer_lrp.attribute(
1563            inputs=input,
1564            target=target,
1565            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1566        )
1567        self.set_forward_func_return_logits_only(False)
1568
1569        attribution_abs_batch_mean = torch.mean(
1570            torch.abs(attribution),
1571            dim=[
1572                i for i in range(attribution.dim()) if i != 1
1573            ],  # average the features over batch samples
1574        )
1575
1576        importance_step_layer = attribution_abs_batch_mean
1577        importance_step_layer = importance_step_layer.detach()
1578
1579        return importance_step_layer
1580
1581    def get_importance_step_layer_cbp_adaptive_contribution(
1582        self,
1583        layer_name: str,
1584        activation: Tensor,
1585    ) -> Tensor:
1586        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer output weights multiplied by absolute values of activation, then divided by the sum of absolute values of layer input weights. It is equivalent to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
1587
1588        **Args:**
1589        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1590        - **activation** (`Tensor`): the activation tensor of the layer for the current training batch, with the unit (feature) dimension at index 1.
1591
1592        **Returns:**
1593        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1594        """
1595        layer = self.backbone.get_layer_by_name(layer_name)
1596
1597        input_weight_abs = torch.abs(layer.weight.data)
1598        input_weight_abs_sum = torch.sum(
1599            input_weight_abs,
1600            dim=[
1601                i for i in range(input_weight_abs.dim()) if i != 0
1602            ],  # sum over the input dimension
1603        )
1604        input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum)
1605
1606        output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
1607        output_weight_abs_sum = torch.sum(
1608            output_weight_abs,
1609            dim=[
1610                i for i in range(output_weight_abs.dim()) if i != 1
1611            ],  # sum over the output dimension
1612        )
1613
1614        activation_abs_batch_mean = torch.mean(
1615            torch.abs(activation),
1616            dim=[
1617                i for i in range(activation.dim()) if i != 1
1618            ],  # average the features over batch samples
1619        )
1620
1621        importance_step_layer = (
1622            output_weight_abs_sum
1623            * activation_abs_batch_mean
1624            * input_weight_abs_sum_reciprocal
1625        )
1626        importance_step_layer = importance_step_layer.detach()
1627
1628        return importance_step_layer

FG-AdaHAT (Fine-Grained Adaptive Hard Attention to the Task) algorithm.

An architecture-based continual learning approach that improves AdaHAT (Adaptive Hard Attention to the Task) by introducing fine-grained neuron-wise importance measures guiding the adaptive adjustment mechanism in AdaHAT.

We implement FG-AdaHAT as a subclass of AdaHAT, as it reuses AdaHAT's summative mask and other components.

FGAdaHAT( backbone: clarena.backbones.HATMaskBackbone, heads: clarena.heads.HeadsTIL, adjustment_intensity: float, importance_type: str, importance_summing_strategy: str, importance_scheduler_type: str, neuron_to_weight_importance_aggregation_mode: str, s_max: float, clamp_threshold: float, mask_sparsity_reg_factor: float, mask_sparsity_reg_mode: str = 'original', base_importance: float = 0.01, base_mask_sparsity_reg: float = 0.1, base_linear: float = 10, filter_by_cumulative_mask: bool = False, filter_unmasked_importance: bool = True, step_multiply_training_mask: bool = True, task_embedding_init_mode: str = 'N01', importance_summing_strategy_linear_step: float | None = None, importance_summing_strategy_exponential_rate: float | None = None, importance_summing_strategy_log_base: float | None = None, non_algorithmic_hparams: dict[str, typing.Any] = {})
 45    def __init__(
 46        self,
 47        backbone: HATMaskBackbone,
 48        heads: HeadsTIL,
 49        adjustment_intensity: float,
 50        importance_type: str,
 51        importance_summing_strategy: str,
 52        importance_scheduler_type: str,
 53        neuron_to_weight_importance_aggregation_mode: str,
 54        s_max: float,
 55        clamp_threshold: float,
 56        mask_sparsity_reg_factor: float,
 57        mask_sparsity_reg_mode: str = "original",
 58        base_importance: float = 0.01,
 59        base_mask_sparsity_reg: float = 0.1,
 60        base_linear: float = 10,
 61        filter_by_cumulative_mask: bool = False,
 62        filter_unmasked_importance: bool = True,
 63        step_multiply_training_mask: bool = True,
 64        task_embedding_init_mode: str = "N01",
 65        importance_summing_strategy_linear_step: float | None = None,
 66        importance_summing_strategy_exponential_rate: float | None = None,
 67        importance_summing_strategy_log_base: float | None = None,
 68        non_algorithmic_hparams: dict[str, Any] = {},
 69    ) -> None:
 70        r"""Initialize the FG-AdaHAT algorithm with the network.
 71
 72        **Args:**
 73        - **backbone** (`HATMaskBackbone`): must be a backbone network with the HAT mask mechanism.
 74        - **heads** (`HeadsTIL`): output heads. FG-AdaHAT supports only TIL (Task-Incremental Learning).
 75        - **adjustment_intensity** (`float`): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
 76        - **importance_type** (`str`): the type of neuron-wise importance, must be one of:
 77            1. 'input_weight_abs_sum': sum of absolute input weights;
 78            2. 'output_weight_abs_sum': sum of absolute output weights;
 79            3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
 80            4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
 81            5. 'activation_abs': absolute activation;
 82            6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
 83            7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
 84            8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
 85            9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
 86            10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
 87            11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
 88            12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
 89            13. 'conductance_abs': absolute layer conductance;
 90            14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
 91            15. 'gradcam_abs': absolute Grad-CAM;
 92            16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
 93            17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
 94            18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
 95            19. 'integrated_gradients_abs': absolute Integrated Gradients;
 96            20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
 97            21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
 98            22. 'cbp_adaptation': the adaptation function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
 99            23. 'cbp_adaptive_contribution': the adaptive contribution function in [Continual Backpropagation (CBP)](https://www.nature.com/articles/s41586-024-07711-7);
100        - **importance_summing_strategy** (`str`): the strategy to sum neuron-wise importance for previous tasks, must be one of:
101            1. 'add_latest': add the latest neuron-wise importance to the summative importance;
102            2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
103            3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
104            4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
105            5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID;
106            6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID;
107            7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
108            8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
109            9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID;
110        - **importance_scheduler_type** (`str`): the scheduler for importance, i.e., the factor $c^t$ that multiplies the parameter importance. Must be one of:
111            1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument `base_linear`), and $b_R$ is the base mask sparsity regularization factor (see argument `base_mask_sparsity_reg`);
112            2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
113            3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
114        - **neuron_to_weight_importance_aggregation_mode** (`str`): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
115            1. 'min': take the minimum of neuron-wise importance for each weight;
116            2. 'max': take the maximum of neuron-wise importance for each weight;
117            3. 'mean': take the mean of neuron-wise importance for each weight.
118        - **s_max** (`float`): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
119        - **clamp_threshold** (`float`): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
120        - **mask_sparsity_reg_factor** (`float`): hyperparameter, the regularization factor for mask sparsity.
121        - **mask_sparsity_reg_mode** (`str`): the mode of mask sparsity regularization, must be one of:
122            1. 'original' (default): the original mask sparsity regularization in the [HAT paper](http://proceedings.mlr.press/v80/serra18a).
123            2. 'cross': the cross version of mask sparsity regularization.
124        - **base_importance** (`float`): base value added to importance ($b_I$ in the paper). Default: 0.01.
125        - **base_mask_sparsity_reg** (`float`): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
126        - **base_linear** (`float`): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
127        - **filter_by_cumulative_mask** (`bool`): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False.
128        - **filter_unmasked_importance** (`bool`): whether to filter unmasked importance values (set to 0) at the end of task training. Default: True.
129        - **step_multiply_training_mask** (`bool`): whether to multiply the training mask to the importance at each training step. Default: True.
130        - **task_embedding_init_mode** (`str`): the initialization mode for task embeddings, must be one of:
131            1. 'N01' (default): standard normal distribution $N(0, 1)$.
132            2. 'U-11': uniform distribution $U(-1, 1)$.
133            3. 'U01': uniform distribution $U(0, 1)$.
134            4. 'U-10': uniform distribution $U(-1, 0)$.
135            5. 'last': inherit the task embedding from the last task.
136        - **importance_summing_strategy_linear_step** (`float` | `None`): linear step for the importance summing strategy (used when `importance_summing_strategy` is 'linear_decrease'). Must be > 0.
137        - **importance_summing_strategy_exponential_rate** (`float` | `None`): exponential rate for the importance summing strategy (used when `importance_summing_strategy` is 'exponential_decrease'). Must be > 1.
138        - **importance_summing_strategy_log_base** (`float` | `None`): base for the logarithm in the importance summing strategy (used when `importance_summing_strategy` is 'log_decrease'). Must be > 1.
139        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters (those not related to the algorithm itself, such as optimizer and learning rate scheduler configurations) passed to this `LightningModule` object from the config. They are saved through Lightning's `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
140
141        """
142        super().__init__(
143            backbone=backbone,
144            heads=heads,
145            adjustment_mode=None,  # use FG-AdaHAT's own adjustment mechanism
146            adjustment_intensity=adjustment_intensity,
147            s_max=s_max,
148            clamp_threshold=clamp_threshold,
149            mask_sparsity_reg_factor=mask_sparsity_reg_factor,
150            mask_sparsity_reg_mode=mask_sparsity_reg_mode,
151            task_embedding_init_mode=task_embedding_init_mode,
152            epsilon=base_mask_sparsity_reg,  # the epsilon is now the base mask sparsity regularization factor
153            non_algorithmic_hparams=non_algorithmic_hparams,
154        )
155
156        # save additional algorithmic hyperparameters
157        self.save_hyperparameters(
158            "adjustment_intensity",
159            "importance_type",
160            "importance_summing_strategy",
161            "importance_scheduler_type",
162            "neuron_to_weight_importance_aggregation_mode",
163            "s_max",
164            "clamp_threshold",
165            "mask_sparsity_reg_factor",
166            "mask_sparsity_reg_mode",
167            "base_importance",
168            "base_mask_sparsity_reg",
169            "base_linear",
170            "filter_by_cumulative_mask",
171            "filter_unmasked_importance",
172            "step_multiply_training_mask",
173        )
174
175        self.importance_type: str | None = importance_type
176        r"""The type of the neuron-wise importance added to AdaHAT importance."""
177
178        self.importance_scheduler_type: str = importance_scheduler_type
179        r"""The type of the importance scheduler."""
180        self.neuron_to_weight_importance_aggregation_mode: str = (
181            neuron_to_weight_importance_aggregation_mode
182        )
183        r"""The mode of aggregation from neuron-wise to weight-wise importance. """
184        self.filter_by_cumulative_mask: bool = filter_by_cumulative_mask
185        r"""The flag to filter importance by the cumulative mask when calculating the adjustment rate."""
186        self.filter_unmasked_importance: bool = filter_unmasked_importance
187        r"""The flag to filter unmasked importance values (set them to 0) at the end of task training."""
188        self.step_multiply_training_mask: bool = step_multiply_training_mask
189        r"""The flag to multiply the training mask to the importance at each training step."""
190
191        # importance summing strategy
192        self.importance_summing_strategy: str = importance_summing_strategy
193        r"""The strategy to sum the neuron-wise importance for previous tasks."""
194        if importance_summing_strategy_linear_step is not None:
195            self.importance_summing_strategy_linear_step: float = (
196                importance_summing_strategy_linear_step
197            )
198            r"""The linear step for the importance summing strategy (only when `importance_summing_strategy` is 'linear_decrease')."""
199        if importance_summing_strategy_exponential_rate is not None:
200            self.importance_summing_strategy_exponential_rate: float = (
201                importance_summing_strategy_exponential_rate
202            )
203            r"""The exponential rate for the importance summing strategy (only when `importance_summing_strategy` is 'exponential_decrease'). """
204        if importance_summing_strategy_log_base is not None:
205            self.importance_summing_strategy_log_base: float = (
206                importance_summing_strategy_log_base
207            )
208            r"""The base for the logarithm in the importance summing strategy (only when `importance_summing_strategy` is 'log_decrease'). """
209
210        # base values
211        self.base_importance: float = base_importance
212        r"""The base value added to the importance to avoid zero. """
213        self.base_mask_sparsity_reg: float = base_mask_sparsity_reg
214        r"""The base value added to the mask sparsity regularization to avoid zero. """
215        self.base_linear: float = base_linear
216        r"""The base value added to the linear factor in the importance scheduler to avoid zero. """
217
218        self.importances: dict[int, dict[str, Tensor]] = {}
219        r"""The min-max scaled ($[0, 1]$) neuron-wise importance of units, $I^{\tau}_{l}$ in the paper. Keys are task IDs; values are dicts mapping layer names to the importance tensor of each layer. Each importance tensor has the same size as the feature tensor, (number of units, ). """
220        self.summative_importance_for_previous_tasks: dict[str, Tensor] = {}
221        r"""The summative neuron-wise importance values of units for previous tasks before the current task `self.task_id`. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ). """
222
223        self.num_steps_t: int
224        r"""The number of training steps for the current task `self.task_id`."""
225        # set manual optimization
226        self.automatic_optimization = False
227
228        FGAdaHAT.sanity_check(self)

Initialize the FG-AdaHAT algorithm with the network.

Args:

  • backbone (HATMaskBackbone): must be a backbone network with the HAT mask mechanism.
  • heads (HeadsTIL): output heads. FG-AdaHAT supports only TIL (Task-Incremental Learning).
  • adjustment_intensity (float): hyperparameter, controls the overall intensity of gradient adjustment (the $\alpha$ in the paper).
  • importance_type (str): the type of neuron-wise importance, must be one of:
    1. 'input_weight_abs_sum': sum of absolute input weights;
    2. 'output_weight_abs_sum': sum of absolute output weights;
    3. 'input_weight_gradient_abs_sum': sum of absolute gradients of the input weights (Input Gradients (IG) in the paper);
    4. 'output_weight_gradient_abs_sum': sum of absolute gradients of the output weights (Output Gradients (OG) in the paper);
    5. 'activation_abs': absolute activation;
    6. 'input_weight_abs_sum_x_activation_abs': sum of absolute input weights multiplied by absolute activation (Input Contribution Utility (ICU) in the paper);
    7. 'output_weight_abs_sum_x_activation_abs': sum of absolute output weights multiplied by absolute activation (Contribution Utility (CU) in the paper);
    8. 'gradient_x_activation_abs': absolute gradient (the saliency) multiplied by activation;
    9. 'input_weight_gradient_square_sum': sum of squared gradients of the input weights;
    10. 'output_weight_gradient_square_sum': sum of squared gradients of the output weights;
    11. 'input_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the input weights multiplied by absolute activation (Activation Fisher Information (AFI) in the paper);
    12. 'output_weight_gradient_square_sum_x_activation_abs': sum of squared gradients of the output weights multiplied by absolute activation;
    13. 'conductance_abs': absolute layer conductance;
    14. 'internal_influence_abs': absolute internal influence (Internal Influence (II) in the paper);
    15. 'gradcam_abs': absolute Grad-CAM;
    16. 'deeplift_abs': absolute DeepLIFT (DeepLIFT (DL) in the paper);
    17. 'deepliftshap_abs': absolute DeepLIFT-SHAP;
    18. 'gradientshap_abs': absolute Gradient-SHAP (Gradient SHAP (GS) in the paper);
    19. 'integrated_gradients_abs': absolute Integrated Gradients;
    20. 'feature_ablation_abs': absolute Feature Ablation (Feature Ablation (FA) in the paper);
    21. 'lrp_abs': absolute Layer-wise Relevance Propagation (LRP);
    22. 'cbp_adaptation': the adaptation function in Continual Backpropagation (CBP);
    23. 'cbp_adaptive_contribution': the adaptive contribution function in Continual Backpropagation (CBP);
  • importance_summing_strategy (str): the strategy to sum neuron-wise importance for previous tasks, must be one of:
    1. 'add_latest': add the latest neuron-wise importance to the summative importance;
    2. 'add_all': add all previous neuron-wise importance (including the latest) to the summative importance;
    3. 'add_average': add the average of all previous neuron-wise importance (including the latest) to the summative importance;
    4. 'linear_decrease': weigh the previous neuron-wise importance by a linear factor that decreases with the task ID;
    5. 'quadratic_decrease': weigh the previous neuron-wise importance that decreases quadratically with the task ID;
    6. 'cubic_decrease': weigh the previous neuron-wise importance that decreases cubically with the task ID;
    7. 'exponential_decrease': weigh the previous neuron-wise importance by an exponential factor that decreases with the task ID;
    8. 'log_decrease': weigh the previous neuron-wise importance by a logarithmic factor that decreases with the task ID;
    9. 'factorial_decrease': weigh the previous neuron-wise importance that decreases factorially with the task ID;
  • importance_scheduler_type (str): the scheduler for importance, i.e., the factor $c^t$ that multiplies the parameter importance. Must be one of:
    1. 'linear_sparsity_reg': $c^t = (t+b_L) \cdot [R(M^t, M^{<t}) + b_R]$, where $R(M^t, M^{<t})$ is the mask sparsity regularization between the current task and previous tasks, $b_L$ is the base linear factor (see argument base_linear), and $b_R$ is the base mask sparsity regularization factor (see argument base_mask_sparsity_reg);
    2. 'sparsity_reg': $c^t = [R(M^t, M^{<t}) + b_R]$;
    3. 'summative_mask_sparsity_reg': $c^t_{l,ij} = \left(\min \left(m^{<t, \text{sum}}_{l,i}, m^{<t, \text{sum}}_{l-1,j}\right)+b_L\right) \cdot [R(M^t, M^{<t}) + b_R]$.
  • neuron_to_weight_importance_aggregation_mode (str): aggregation mode from neuron-wise to weight-wise importance ($\text{Agg}(\cdot)$ in the paper), must be one of:
    1. 'min': take the minimum of neuron-wise importance for each weight;
    2. 'max': take the maximum of neuron-wise importance for each weight;
    3. 'mean': take the mean of neuron-wise importance for each weight.
  • s_max (float): hyperparameter, the maximum scaling factor in the gate function. See Sec. 2.4 "Hard Attention Training" in the HAT paper.
  • clamp_threshold (float): the threshold for task embedding gradient compensation. See Sec. 2.5 "Embedding Gradient Compensation" in the HAT paper.
  • mask_sparsity_reg_factor (float): hyperparameter, the regularization factor for mask sparsity.
  • mask_sparsity_reg_mode (str): the mode of mask sparsity regularization, must be one of:
    1. 'original' (default): the original mask sparsity regularization in the HAT paper.
    2. 'cross': the cross version of mask sparsity regularization.
  • base_importance (float): base value added to importance ($b_I$ in the paper). Default: 0.01.
  • base_mask_sparsity_reg (float): base value added to mask sparsity regularization factor in the importance scheduler ($b_R$ in the paper). Default: 0.1.
  • base_linear (float): base value added to the linear factor in the importance scheduler ($b_L$ in the paper). Default: 10.
  • filter_by_cumulative_mask (bool): whether to multiply the cumulative mask to the importance when calculating adjustment rate. Default: False.
  • filter_unmasked_importance (bool): whether to filter unmasked importance values (set to 0) at the end of task training. Default: True.
  • step_multiply_training_mask (bool): whether to multiply the training mask to the importance at each training step. Default: True.
  • task_embedding_init_mode (str): the initialization mode for task embeddings, must be one of:
    1. 'N01' (default): standard normal distribution $N(0, 1)$.
    2. 'U-11': uniform distribution $U(-1, 1)$.
    3. 'U01': uniform distribution $U(0, 1)$.
    4. 'U-10': uniform distribution $U(-1, 0)$.
    5. 'last': inherit the task embedding from the last task.
  • importance_summing_strategy_linear_step (float | None): linear step for the importance summing strategy (used when importance_summing_strategy is 'linear_decrease'). Must be > 0.
  • importance_summing_strategy_exponential_rate (float | None): exponential rate for the importance summing strategy (used when importance_summing_strategy is 'exponential_decrease'). Must be > 1.
  • importance_summing_strategy_log_base (float | None): base for the logarithm in the importance summing strategy (used when importance_summing_strategy is 'log_decrease'). Must be > 1.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters (those not related to the algorithm itself, such as optimizer and learning rate scheduler configurations) passed to this LightningModule object from the config. They are saved through Lightning's save_hyperparameters() method, which is useful for experiment configuration and reproducibility.
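
As a quick orientation, the sketch below shows how these arguments fit together when constructing the algorithm. It assumes backbone (a HATMaskBackbone) and heads (a HeadsTIL) have already been built elsewhere, and the hyperparameter values are purely illustrative, not recommended settings.

    from clarena.cl_algorithms import FGAdaHAT

    algorithm = FGAdaHAT(
        backbone=backbone,  # HATMaskBackbone instance, assumed to exist
        heads=heads,        # HeadsTIL instance, assumed to exist
        adjustment_intensity=1e-3,                 # alpha in the paper
        importance_type="activation_abs",
        importance_summing_strategy="add_latest",
        importance_scheduler_type="linear_sparsity_reg",
        neuron_to_weight_importance_aggregation_mode="min",
        s_max=400.0,
        clamp_threshold=50.0,
        mask_sparsity_reg_factor=0.75,
    )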
importance_type: str | None

The type of the neuron-wise importance added to AdaHAT importance.

importance_scheduler_type: str

The type of the importance scheduler.

neuron_to_weight_importance_aggregation_mode: str

The mode of aggregation from neuron-wise to weight-wise importance.

filter_by_cumulative_mask: bool

The flag to filter importance by the cumulative mask when calculating the adjustment rate.

filter_unmasked_importance: bool

The flag to filter unmasked importance values (set them to 0) at the end of task training.

step_multiply_training_mask: bool

The flag to multiply the training mask to the importance at each training step.

importance_summing_strategy: str

The strategy to sum the neuron-wise importance for previous tasks.

base_importance: float

The base value added to the importance to avoid zero.

base_mask_sparsity_reg: float

The base value added to the mask sparsity regularization to avoid zero.

base_linear: float

The base value added to the linear factor in the importance scheduler to avoid zero.

importances: dict[int, dict[str, torch.Tensor]]

The min-max scaled ($[0, 1]$) neuron-wise importance of units, $I^{\tau}_{l}$ in the paper. Keys are task IDs; values are dicts mapping layer names to the importance tensor of each layer. Each importance tensor has the same size as the feature tensor, (number of units, ).

summative_importance_for_previous_tasks: dict[str, torch.Tensor]

The summative neuron-wise importance values of units for previous tasks before the current task self.task_id. See $I^{<t}_{l}$ in the paper. Keys are layer names and values are the summative importance tensor for the layer. The summative importance tensor has the same size as the feature tensor with size (number of units, ).

num_steps_t: int

The number of training steps for the current task self.task_id.

automatic_optimization: bool
290    @property
291    def automatic_optimization(self) -> bool:
292        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
293        return self._automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().
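
Since automatic optimization is disabled, the optimizer calls happen inside the training step. The sketch below shows only the generic PyTorch Lightning manual-optimization pattern, not FG-AdaHAT's actual training_step (which additionally applies HAT masking, mask sparsity regularization, and the gradient adjustment described below).

    import torch
    import torch.nn.functional as F
    from lightning import LightningModule

    class ManualOptimizationExample(LightningModule):
        def __init__(self, model: torch.nn.Module):
            super().__init__()
            self.model = model
            self.automatic_optimization = False   # what this property controls

        def training_step(self, batch, batch_idx):
            x, y = batch
            opt = self.optimizers()       # optimizer from configure_optimizers()
            logits = self.model(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad()
            self.manual_backward(loss)    # replaces loss.backward()
            opt.step()
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.model.parameters(), lr=0.1)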

def sanity_check(self) -> None:
230    def sanity_check(self) -> None:
231        r"""Sanity check."""
232
233        # check importance type
234        if self.importance_type not in [
235            "input_weight_abs_sum",
236            "output_weight_abs_sum",
237            "input_weight_gradient_abs_sum",
238            "output_weight_gradient_abs_sum",
239            "activation_abs",
240            "input_weight_abs_sum_x_activation_abs",
241            "output_weight_abs_sum_x_activation_abs",
242            "gradient_x_activation_abs",
243            "input_weight_gradient_square_sum",
244            "output_weight_gradient_square_sum",
245            "input_weight_gradient_square_sum_x_activation_abs",
246            "output_weight_gradient_square_sum_x_activation_abs",
247            "conductance_abs",
248            "internal_influence_abs",
249            "gradcam_abs",
250            "deeplift_abs",
251            "deepliftshap_abs",
252            "gradientshap_abs",
253            "integrated_gradients_abs",
254            "feature_ablation_abs",
255            "lrp_abs",
256            "cbp_adaptation",
257            "cbp_adaptive_contribution",
258        ]:
259            raise ValueError(
260                f"importance_type must be one of the predefined types, but got {self.importance_type}"
261            )
262
263        # check importance summing strategy
264        if self.importance_summing_strategy not in [
265            "add_latest",
266            "add_all",
267            "add_average",
268            "linear_decrease",
269            "quadratic_decrease",
270            "cubic_decrease",
271            "exponential_decrease",
272            "log_decrease",
273            "factorial_decrease",
274        ]:
275            raise ValueError(
276                f"importance_summing_strategy must be one of the predefined strategies, but got {self.importance_summing_strategy}"
277            )
278
279        # check importance scheduler type
280        if self.importance_scheduler_type not in [
281            "linear_sparsity_reg",
282            "sparsity_reg",
283            "summative_mask_sparsity_reg",
284        ]:
285            raise ValueError(
286                f"importance_scheduler_type must be one of the predefined types, but got {self.importance_scheduler_type}"
287            )
288
289        # check neuron to weight importance aggregation mode
290        if self.neuron_to_weight_importance_aggregation_mode not in [
291            "min",
292            "max",
293            "mean",
294        ]:
295            raise ValueError(
296                f"neuron_to_weight_importance_aggregation_mode must be one of the predefined modes, but got {self.neuron_to_weight_importance_aggregation_mode}"
297            )
298
299        # check base values
300        if self.base_importance < 0:
301            raise ValueError(
302                f"base_importance must be >= 0, but got {self.base_importance}"
303            )
304        if self.base_mask_sparsity_reg <= 0:
305            raise ValueError(
306                f"base_mask_sparsity_reg must be > 0, but got {self.base_mask_sparsity_reg}"
307            )
308        if self.base_linear <= 0:
309            raise ValueError(f"base_linear must be > 0, but got {self.base_linear}")

Sanity check.

def on_train_start(self) -> None:
311    def on_train_start(self) -> None:
312        r"""Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization."""
313        super().on_train_start()
314
315        self.importances[self.task_id] = (
316            {}
317        )  # initialize the importance for the current task
318
319        # initialize the neuron importance at the beginning of each task. This should not be called in `__init__()` method because `self.device` is not available at that time.
320        for layer_name in self.backbone.weighted_layer_names:
321            layer = self.backbone.get_layer_by_name(
322                layer_name
323            )  # get the layer by its name
324            num_units = layer.weight.shape[0]
325
326            # initialize the accumulated importance at the beginning of each task
327            self.importances[self.task_id][layer_name] = torch.zeros(num_units).to(
328                self.device
329            )
330
331            # reset the number of steps counter for the current task
332            self.num_steps_t = 0
333
334            # initialize the summative neuron-wise importance at the beginning of the first task
335            if self.task_id == 1:
336                self.summative_importance_for_previous_tasks[layer_name] = torch.zeros(
337                    num_units
338                ).to(
339                    self.device
340                )  # the summative neuron-wise importance for previous tasks $I^{<t}_{l}$ is initialized as zeros mask when $t=1$

Initialize neuron importance accumulation variable for each layer as zeros, in addition to AdaHAT's summative mask initialization.
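
In plain terms, at the start of each task the accumulated importance of every weighted layer is reset to a zero vector with one entry per unit, and at the first task the summative importance for previous tasks also starts at zero. A minimal sketch of that state (layer names and unit counts are made up):

    import torch

    weighted_layer_names = ["fc1", "fc2"]   # hypothetical layer names
    num_units = {"fc1": 128, "fc2": 64}     # hypothetical unit counts

    importances_t = {
        name: torch.zeros(num_units[name]) for name in weighted_layer_names
    }  # accumulated importance for the current task
    summative_importance = {
        name: torch.zeros(num_units[name]) for name in weighted_layer_names
    }  # I^{<t}_l, created only when t = 1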

def clip_grad_by_adjustment( self, network_sparsity: dict[str, torch.Tensor]) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], torch.Tensor]:
342    def clip_grad_by_adjustment(
343        self,
344        network_sparsity: dict[str, Tensor],
345    ) -> tuple[dict[str, Tensor], dict[str, Tensor], Tensor]:
346        r"""Clip the gradients by the adjustment rate. See Eq. (1) in the paper.
347
348        Note that the task embeddings cover every layer in the backbone network, so no parameters are left out of this mechanism. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.
349
350        Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
351
352        **Args:**
353        - **network_sparsity** (`dict[str, Tensor]`): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.
354
355        **Returns:**
356        - **adjustment_rate_weight** (`dict[str, Tensor]`): the adjustment rate for weights. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
357        - **adjustment_rate_bias** (`dict[str, Tensor]`): the adjustment rate for biases. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors.
358        - **capacity** (`Tensor`): the calculated network capacity.
359        """
360
361        # initialize network capacity metric
362        capacity = HATNetworkCapacityMetric().to(self.device)
363        adjustment_rate_weight = {}
364        adjustment_rate_bias = {}
365
366        # calculate the adjustment rate for gradients of the parameters, both weights and biases (if they exist). See Eq. (2) in the paper
367        for layer_name in self.backbone.weighted_layer_names:
368
369            layer = self.backbone.get_layer_by_name(
370                layer_name
371            )  # get the layer by its name
372
373            # placeholder for the adjustment rate to avoid the error of using it before assignment
374            adjustment_rate_weight_layer = 1
375            adjustment_rate_bias_layer = 1
376
377            # aggregate the neuron-wise importance to weight-wise importance. Note that the neuron-wise importance has already been min-max scaled to $[0, 1]$ in the `on_train_batch_end()` method, added the base value, and filtered by the mask
378            weight_importance, bias_importance = (
379                self.backbone.get_layer_measure_parameter_wise(
380                    neuron_wise_measure=self.summative_importance_for_previous_tasks,
381                    layer_name=layer_name,
382                    aggregation_mode=self.neuron_to_weight_importance_aggregation_mode,
383                )
384            )
385
386            weight_mask, bias_mask = self.backbone.get_layer_measure_parameter_wise(
387                neuron_wise_measure=self.cumulative_mask_for_previous_tasks,
388                layer_name=layer_name,
389                aggregation_mode="min",
390            )
391
392            # filter the weight importance by the cumulative mask
393            if self.filter_by_cumulative_mask:
394                weight_importance = weight_importance * weight_mask
395                bias_importance = bias_importance * bias_mask
396
397            network_sparsity_layer = network_sparsity[layer_name]
398
399            # calculate importance scheduler (the factor of importance). See Eq. (3) in the paper
400            factor = network_sparsity_layer + self.base_mask_sparsity_reg
401            if self.importance_scheduler_type == "linear_sparsity_reg":
402                factor = factor * (self.task_id + self.base_linear)
403            elif self.importance_scheduler_type == "sparsity_reg":
404                pass
405            elif self.importance_scheduler_type == "summative_mask_sparsity_reg":
406                factor = factor * (
407                    self.summative_mask_for_previous_tasks + self.base_linear
408                )
409
410            # calculate the adjustment rate
411            adjustment_rate_weight_layer = torch.div(
412                self.adjustment_intensity,
413                (factor * weight_importance + self.adjustment_intensity),
414            )
415
416            adjustment_rate_bias_layer = torch.div(
417                self.adjustment_intensity,
418                (factor * bias_importance + self.adjustment_intensity),
419            )
420
421            # apply the adjustment rate to the gradients
422            layer.weight.grad.data *= adjustment_rate_weight_layer
423            if layer.bias is not None:
424                layer.bias.grad.data *= adjustment_rate_bias_layer
425
426            # store the adjustment rate for logging
427            adjustment_rate_weight[layer_name] = adjustment_rate_weight_layer
428            if layer.bias is not None:
429                adjustment_rate_bias[layer_name] = adjustment_rate_bias_layer
430
431            # update network capacity metric
432            capacity.update(adjustment_rate_weight_layer, adjustment_rate_bias_layer)
433
434        return adjustment_rate_weight, adjustment_rate_bias, capacity.compute()

Clip the gradients by the adjustment rate. See Eq. (1) in the paper.

Note that the task embeddings cover every layer in the backbone network, so no parameters are left out of this mechanism. This applies not only to parameters between layers with task embeddings, but also to those before the first layer, which are handled separately in the code.

Network capacity is measured alongside this method. Network capacity is defined as the average adjustment rate over all parameters. See Sec. 4.1 in the AdaHAT paper.

Args:

  • network_sparsity (dict[str, Tensor]): the network sparsity (i.e., mask sparsity loss of each layer) for the current task. Keys are layer names and values are the network sparsity values. It is used to calculate the adjustment rate for gradients. In FG-AdaHAT, it is used to construct the importance scheduler.

Returns:

  • adjustment_rate_weight (dict[str, Tensor]): the adjustment rate for weights. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
  • adjustment_rate_bias (dict[str, Tensor]): the adjustment rate for biases. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors.
  • capacity (Tensor): the calculated network capacity.
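
To make the computation above concrete, here is a small self-contained sketch of the adjustment rate for one layer under the 'linear_sparsity_reg' scheduler, mirroring the source shown above; all tensor values and sizes are made up.

    import torch

    adjustment_intensity = 1e-3                  # alpha in the paper
    base_mask_sparsity_reg = 0.1                 # b_R
    base_linear = 10.0                           # b_L
    task_id = 3                                  # current task t
    network_sparsity_layer = torch.tensor(0.4)   # R(M^t, M^{<t}) of this layer

    # importance scheduler: c^t = (t + b_L) * (R + b_R)
    factor = (network_sparsity_layer + base_mask_sparsity_reg) * (task_id + base_linear)

    # weight-wise importance aggregated from neuron-wise importance (made up)
    weight_importance = torch.rand(16, 8)

    # adjustment rate: alpha / (c^t * I + alpha); close to 0 for important weights,
    # close to 1 for unimportant ones
    adjustment_rate_weight = adjustment_intensity / (
        factor * weight_importance + adjustment_intensity
    )
    # the weight gradients are then multiplied element-wise by this rate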
def on_train_batch_end(self, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
436    def on_train_batch_end(
437        self, outputs: dict[str, Any], batch: Any, batch_idx: int
438    ) -> None:
439        r"""Calculate the step-wise importance, update the accumulated importance and number of steps counter after each training step.
440
441        **Args:**
442        - **outputs** (`dict[str, Any]`): outputs of the training step (returns of `training_step()` in `CLAlgorithm`).
443        - **batch** (`Any`): training data batch.
444        - **batch_idx** (`int`): index of the current batch (for mask figure file name).
445        """
446
447        # get potential useful information from training batch
448        activations = outputs["activations"]
449        input = outputs["input"]
450        target = outputs["target"]
451        mask = outputs["mask"]
452        num_batches = self.trainer.num_training_batches
453
454        for layer_name in self.backbone.weighted_layer_names:
455            # layer-wise operation
456
457            activation = activations[layer_name]
458
459            # calculate neuron-wise importance of the training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper.
460            if self.importance_type == "input_weight_abs_sum":
461                importance_step = self.get_importance_step_layer_weight_abs_sum(
462                    layer_name=layer_name,
463                    if_output_weight=False,
464                    reciprocal=False,
465                )
466            elif self.importance_type == "output_weight_abs_sum":
467                importance_step = self.get_importance_step_layer_weight_abs_sum(
468                    layer_name=layer_name,
469                    if_output_weight=True,
470                    reciprocal=False,
471                )
472            elif self.importance_type == "input_weight_gradient_abs_sum":
473                importance_step = (
474                    self.get_importance_step_layer_weight_gradient_abs_sum(
475                        layer_name=layer_name, if_output_weight=False
476                    )
477                )
478            elif self.importance_type == "output_weight_gradient_abs_sum":
479                importance_step = (
480                    self.get_importance_step_layer_weight_gradient_abs_sum(
481                        layer_name=layer_name, if_output_weight=True
482                    )
483                )
484            elif self.importance_type == "activation_abs":
485                importance_step = self.get_importance_step_layer_activation_abs(
486                    activation=activation
487                )
488            elif self.importance_type == "input_weight_abs_sum_x_activation_abs":
489                importance_step = (
490                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
491                        layer_name=layer_name,
492                        activation=activation,
493                        if_output_weight=False,
494                    )
495                )
496            elif self.importance_type == "output_weight_abs_sum_x_activation_abs":
497                importance_step = (
498                    self.get_importance_step_layer_weight_abs_sum_x_activation_abs(
499                        layer_name=layer_name,
500                        activation=activation,
501                        if_output_weight=True,
502                    )
503                )
504            elif self.importance_type == "gradient_x_activation_abs":
505                importance_step = (
506                    self.get_importance_step_layer_gradient_x_activation_abs(
507                        layer_name=layer_name,
508                        input=input,
509                        target=target,
510                        batch_idx=batch_idx,
511                        num_batches=num_batches,
512                    )
513                )
514            elif self.importance_type == "input_weight_gradient_square_sum":
515                importance_step = (
516                    self.get_importance_step_layer_weight_gradient_square_sum(
517                        layer_name=layer_name,
518                        activation=activation,
519                        if_output_weight=False,
520                    )
521                )
522            elif self.importance_type == "output_weight_gradient_square_sum":
523                importance_step = (
524                    self.get_importance_step_layer_weight_gradient_square_sum(
525                        layer_name=layer_name,
526                        activation=activation,
527                        if_output_weight=True,
528                    )
529                )
530            elif (
531                self.importance_type
532                == "input_weight_gradient_square_sum_x_activation_abs"
533            ):
534                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
535                    layer_name=layer_name,
536                    activation=activation,
537                    if_output_weight=False,
538                )
539            elif (
540                self.importance_type
541                == "output_weight_gradient_square_sum_x_activation_abs"
542            ):
543                importance_step = self.get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
544                    layer_name=layer_name,
545                    activation=activation,
546                    if_output_weight=True,
547                )
548            elif self.importance_type == "conductance_abs":
549                importance_step = self.get_importance_step_layer_conductance_abs(
550                    layer_name=layer_name,
551                    input=input,
552                    baselines=None,
553                    target=target,
554                    batch_idx=batch_idx,
555                    num_batches=num_batches,
556                )
557            elif self.importance_type == "internal_influence_abs":
558                importance_step = self.get_importance_step_layer_internal_influence_abs(
559                    layer_name=layer_name,
560                    input=input,
561                    baselines=None,
562                    target=target,
563                    batch_idx=batch_idx,
564                    num_batches=num_batches,
565                )
566            elif self.importance_type == "gradcam_abs":
567                importance_step = self.get_importance_step_layer_gradcam_abs(
568                    layer_name=layer_name,
569                    input=input,
570                    target=target,
571                    batch_idx=batch_idx,
572                    num_batches=num_batches,
573                )
574            elif self.importance_type == "deeplift_abs":
575                importance_step = self.get_importance_step_layer_deeplift_abs(
576                    layer_name=layer_name,
577                    input=input,
578                    baselines=None,
579                    target=target,
580                    batch_idx=batch_idx,
581                    num_batches=num_batches,
582                )
583            elif self.importance_type == "deepliftshap_abs":
584                importance_step = self.get_importance_step_layer_deepliftshap_abs(
585                    layer_name=layer_name,
586                    input=input,
587                    baselines=None,
588                    target=target,
589                    batch_idx=batch_idx,
590                    num_batches=num_batches,
591                )
592            elif self.importance_type == "gradientshap_abs":
593                importance_step = self.get_importance_step_layer_gradientshap_abs(
594                    layer_name=layer_name,
595                    input=input,
596                    baselines=None,
597                    target=target,
598                    batch_idx=batch_idx,
599                    num_batches=num_batches,
600                )
601            elif self.importance_type == "integrated_gradients_abs":
602                importance_step = (
603                    self.get_importance_step_layer_integrated_gradients_abs(
604                        layer_name=layer_name,
605                        input=input,
606                        baselines=None,
607                        target=target,
608                        batch_idx=batch_idx,
609                        num_batches=num_batches,
610                    )
611                )
612            elif self.importance_type == "feature_ablation_abs":
613                importance_step = self.get_importance_step_layer_feature_ablation_abs(
614                    layer_name=layer_name,
615                    input=input,
616                    layer_baselines=None,
617                    target=target,
618                    batch_idx=batch_idx,
619                    num_batches=num_batches,
620                )
621            elif self.importance_type == "lrp_abs":
622                importance_step = self.get_importance_step_layer_lrp_abs(
623                    layer_name=layer_name,
624                    input=input,
625                    target=target,
626                    batch_idx=batch_idx,
627                    num_batches=num_batches,
628                )
629            elif self.importance_type == "cbp_adaptation":
630                importance_step = self.get_importance_step_layer_weight_abs_sum(
631                    layer_name=layer_name,
632                    if_output_weight=False,
633                    reciprocal=True,
634                )
635            elif self.importance_type == "cbp_adaptive_contribution":
636                importance_step = (
637                    self.get_importance_step_layer_cbp_adaptive_contribution(
638                        layer_name=layer_name,
639                        activation=activation,
640                    )
641                )
642
643            importance_step = min_max_normalize(
644                importance_step
645            )  # min-max scale the importance to $[0, 1]$. See Eq. (5) in the paper
646
647            # multiply the importance by the training mask. See Eq. (6) in the paper
648            if self.step_multiply_training_mask:
649                importance_step = importance_step * mask[layer_name]
650
651            # update accumulated importance
652            self.importances[self.task_id][layer_name] = (
653                self.importances[self.task_id][layer_name] + importance_step
654            )
655
656        # update number of steps counter
657        self.num_steps_t += 1

Calculate the step-wise importance, then update the accumulated importance and the number-of-steps counter after each training step.

Args:

  • outputs (dict[str, Any]): outputs of the training step (returns of training_step() in CLAlgorithm).
  • batch (Any): training data batch.
  • batch_idx (int): index of the current batch (for mask figure file name).
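
A minimal sketch of this per-step bookkeeping for a single layer, using hypothetical tensors and an inline stand-in for clarena.utils.transforms.min_max_normalize (assumed behavior):

    import torch

    def min_max_normalize(x: torch.Tensor) -> torch.Tensor:
        # inline stand-in: scale the raw importance to [0, 1], as in Eq. (5)
        return (x - x.min()) / (x.max() - x.min() + 1e-12)

    raw_importance_step = torch.rand(8)  # hypothetical raw importance, one value per unit
    training_mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0])  # hypothetical HAT mask

    accumulated_importance = torch.zeros(8)  # running sum over training steps
    num_steps = 0

    importance_step = min_max_normalize(raw_importance_step)
    importance_step = importance_step * training_mask  # keep only masked-in units, Eq. (6)

    accumulated_importance += importance_step
    num_steps += 1  # the sum is averaged over steps at the end of the task
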
def on_train_end(self) -> None:
659    def on_train_end(self) -> None:
660        r"""Additionally calculate neuron-wise importance for previous tasks at the end of training each task."""
661        super().on_train_end()  # store the mask and update cumulative and summative masks
662
663        for layer_name in self.backbone.weighted_layer_names:
664
665            # average the neuron-wise step importance. See Eq. (4) in the paper
666            self.importances[self.task_id][layer_name] = (
667                self.importances[self.task_id][layer_name]
668            ) / self.num_steps_t
669
670            # add the base importance. See Eq. (6) in the paper
671            self.importances[self.task_id][layer_name] = (
672                self.importances[self.task_id][layer_name] + self.base_importance
673            )
674
675            # filter unmasked importance
676            if self.filter_unmasked_importance:
677                self.importances[self.task_id][layer_name] = (
678                    self.importances[self.task_id][layer_name]
679                    * self.backbone.masks[f"{self.task_id}"][layer_name]
680                )
681
682            # calculate the summative neuron-wise importance for previous tasks. See Eq. (4) in the paper
683            if self.importance_summing_strategy == "add_latest":
684                self.summative_importance_for_previous_tasks[
685                    layer_name
686                ] += self.importances[self.task_id][layer_name]
687
688            elif self.importance_summing_strategy == "add_all":
689                for t in range(1, self.task_id + 1):
690                    self.summative_importance_for_previous_tasks[
691                        layer_name
692                    ] += self.importances[t][layer_name]
693
694            elif self.importance_summing_strategy == "add_average":
695                for t in range(1, self.task_id + 1):
696                    self.summative_importance_for_previous_tasks[layer_name] += (
697                        self.importances[t][layer_name] / self.task_id
698                    )
699            else:
700                # start accumulating from 0
701                self.summative_importance_for_previous_tasks[
702                    layer_name
703                ] = torch.zeros_like(
704                    self.summative_importance_for_previous_tasks[layer_name]
705                ).to(
706                    self.device
707                )
708
709                for t in range(1, self.task_id + 1):
710                    if self.importance_summing_strategy == "linear_decrease":
711                        s = self.importance_summing_strategy_linear_step
712                        w_t = s * (self.task_id - t) + 1
713                    elif self.importance_summing_strategy == "quadratic_decrease":
714                        w_t = (self.task_id - t + 1) ** 2
715                    elif self.importance_summing_strategy == "cubic_decrease":
716                        w_t = (self.task_id - t + 1) ** 3
717                    elif self.importance_summing_strategy == "exponential_decrease":
718                        r = self.importance_summing_strategy_exponential_rate
719                        w_t = r ** (self.task_id - t + 1)
720                    elif self.importance_summing_strategy == "log_decrease":
721                        a = self.importance_summing_strategy_log_base
722                        # +1 inside the log avoids log(0) when t == self.task_id
723                        w_t = math.log(self.task_id - t + 1, a) + 1
724                    elif self.importance_summing_strategy == "factorial_decrease":
725                        w_t = math.factorial(self.task_id - t + 1)
726                    else:
727                        raise ValueError(
728                            "Unknown importance_summing_strategy: "
729                            f"{self.importance_summing_strategy}"
730                        )
731
732                    # weight each task's importance by w_t and accumulate it
733                    self.summative_importance_for_previous_tasks[layer_name] += (
734                        self.importances[t][layer_name] * w_t
735                    )

Additionally calculate neuron-wise importance for previous tasks at the end of training each task.
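
To make the weighting concrete, here is a standalone sketch of how the decreasing summing strategies assign a weight w_t to each previous task t and accumulate the weighted sum. The per-task importances and the hyperparameter values (linear_step, rate, log_base) are hypothetical; the weight formulas mirror the branches in the source above:

    import math
    import torch

    def task_weight(strategy: str, task_id: int, t: int,
                    linear_step: float = 1.0, rate: float = 0.5, log_base: float = 2.0) -> float:
        # weight w_t applied to task t's importance when the current task is task_id
        age = task_id - t
        if strategy == "linear_decrease":
            return linear_step * age + 1
        if strategy == "quadratic_decrease":
            return (age + 1) ** 2
        if strategy == "cubic_decrease":
            return (age + 1) ** 3
        if strategy == "exponential_decrease":
            return rate ** (age + 1)
        if strategy == "log_decrease":
            return math.log(age + 1, log_base) + 1
        if strategy == "factorial_decrease":
            return math.factorial(age + 1)
        raise ValueError(f"unknown strategy: {strategy}")

    task_id = 4
    importances = {t: torch.rand(8) for t in range(1, task_id + 1)}  # hypothetical per-task importances

    summative = torch.zeros(8)
    for t in range(1, task_id + 1):
        summative += importances[t] * task_weight("quadratic_decrease", task_id, t)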

def get_importance_step_layer_weight_abs_sum(self, layer_name: str, if_output_weight: bool, reciprocal: bool) -> torch.Tensor:
737    def get_importance_step_layer_weight_abs_sum(
738        self,
739        layer_name: str,
740        if_output_weight: bool,
741        reciprocal: bool,
742    ) -> Tensor:
743        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.
744
745        **Args:**
746        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
747        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
748        - **reciprocal** (`bool`): whether to take reciprocal.
749
750        **Returns:**
751        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
752        """
753        layer = self.backbone.get_layer_by_name(layer_name)
754
755        if not if_output_weight:
756            weight_abs = torch.abs(layer.weight.data)
757            weight_abs_sum = torch.sum(
758                weight_abs,
759                dim=[
760                    i for i in range(weight_abs.dim()) if i != 0
761                ],  # sum over the input dimension
762            )
763        else:
764            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
765            weight_abs_sum = torch.sum(
766                weight_abs,
767                dim=[
768                    i for i in range(weight_abs.dim()) if i != 1
769                ],  # sum over the output dimension
770            )
771
772        if reciprocal:
773            weight_abs_sum_reciprocal = torch.reciprocal(weight_abs_sum)
774            importance_step_layer = weight_abs_sum_reciprocal
775        else:
776            importance_step_layer = weight_abs_sum
777        importance_step_layer = importance_step_layer.detach()
778
779        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input or output weights.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • if_output_weight (bool): whether to use the output weights or input weights.
  • reciprocal (bool): whether to take reciprocal.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
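
A minimal sketch of this measure for a fully connected layer (hypothetical layer sizes; the following layer is given explicitly rather than resolved via the backbone):

    import torch
    from torch import nn

    layer = nn.Linear(16, 8)       # current layer: 8 units
    next_layer = nn.Linear(8, 4)   # hypothetical following layer

    # input-weight version: sum |W| over each unit's incoming weights -> shape (8,)
    input_weight_abs_sum = layer.weight.data.abs().sum(dim=1)

    # output-weight version: sum |W_next| over each unit's outgoing weights -> shape (8,)
    output_weight_abs_sum = next_layer.weight.data.abs().sum(dim=0)

    # the reciprocal of the input-weight sum is the 'cbp_adaptation' importance type
    cbp_adaptation = torch.reciprocal(input_weight_abs_sum)
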
def get_importance_step_layer_weight_gradient_abs_sum(self, layer_name: str, if_output_weight: bool) -> torch.Tensor:
781    def get_importance_step_layer_weight_gradient_abs_sum(
782        self,
783        layer_name: str,
784        if_output_weight: bool,
785    ) -> Tensor:
786        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.
787
788        **Args:**
789        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
790        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
791
792        **Returns:**
793        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
794        """
795        layer = self.backbone.get_layer_by_name(layer_name)
796
797        if not if_output_weight:
798            gradient_abs = torch.abs(layer.weight.grad.data)
799            gradient_abs_sum = torch.sum(
800                gradient_abs,
801                dim=[
802                    i for i in range(gradient_abs.dim()) if i != 0
803                ],  # sum over the input dimension
804            )
805        else:
806            gradient_abs = torch.abs(self.next_layer(layer_name).weight.grad.data)
807            gradient_abs_sum = torch.sum(
808                gradient_abs,
809                dim=[
810                    i for i in range(gradient_abs.dim()) if i != 1
811                ],  # sum over the output dimension
812            )
813
814        importance_step_layer = gradient_abs_sum
815        importance_step_layer = importance_step_layer.detach()
816
817        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of gradients of the layer input or output weights.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • if_output_weight (bool): whether to use the output weights or input weights.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_activation_abs(self, activation: torch.Tensor) -> torch.Tensor:
819    def get_importance_step_layer_activation_abs(
820        self,
821        activation: Tensor,
822    ) -> Tensor:
823        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of [Layer Activation](https://captum.ai/api/layer.html#layer-activation) in Captum.
824
825        **Args:**
826        - **activation** (`Tensor`): the activation tensor of the layer, with the batch dimension first and the unit dimension second, e.g. of size (batch size, number of units).
827
828        **Returns:**
829        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
830        """
831        activation_abs_batch_mean = torch.mean(
832            torch.abs(activation),
833            dim=[
834                i for i in range(activation.dim()) if i != 1
835            ],  # average the features over batch samples
836        )
837        importance_step_layer = activation_abs_batch_mean
838        importance_step_layer = importance_step_layer.detach()
839
840        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute value of activation of the layer. This is our own implementation of Layer Activation in Captum.

Args:

  • activation (Tensor): the activation tensor of the layer, with the batch dimension first and the unit dimension second, e.g. of size (batch size, number of units).

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
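
A minimal sketch, assuming a batched activation with the unit dimension second (for convolutional feature maps the mean is taken over every dimension except the channel dimension):

    import torch

    activation = torch.randn(32, 8)  # hypothetical activations: 32 samples, 8 units

    # mean |activation| over the batch -> one importance value per unit, shape (8,)
    importance_step = activation.abs().mean(dim=0).detach()
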
def get_importance_step_layer_weight_abs_sum_x_activation_abs(self, layer_name: str, activation: torch.Tensor, if_output_weight: bool) -> torch.Tensor:
842    def get_importance_step_layer_weight_abs_sum_x_activation_abs(
843        self,
844        layer_name: str,
845        activation: Tensor,
846        if_output_weight: bool,
847    ) -> Tensor:
848        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).
849
850        **Args:**
851        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
852        - **activation** (`Tensor`): the activation tensor of the layer, with the batch dimension first and the unit dimension second, e.g. of size (batch size, number of units).
853        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
854
855        **Returns:**
856        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
857        """
858        layer = self.backbone.get_layer_by_name(layer_name)
859
860        if not if_output_weight:
861            weight_abs = torch.abs(layer.weight.data)
862            weight_abs_sum = torch.sum(
863                weight_abs,
864                dim=[
865                    i for i in range(weight_abs.dim()) if i != 0
866                ],  # sum over the input dimension
867            )
868        else:
869            weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
870            weight_abs_sum = torch.sum(
871                weight_abs,
872                dim=[
873                    i for i in range(weight_abs.dim()) if i != 1
874                ],  # sum over the output dimension
875            )
876
877        activation_abs_batch_mean = torch.mean(
878            torch.abs(activation),
879            dim=[
880                i for i in range(activation.dim()) if i != 1
881            ],  # average the features over batch samples
882        )
883
884        importance_step_layer = weight_abs_sum * activation_abs_batch_mean
885        importance_step_layer = importance_step_layer.detach()
886
887        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of layer input / output weights multiplied by absolute values of activation. The input weights version is equal to the contribution utility in CBP.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • activation (Tensor): the activation tensor of the layer, with the batch dimension first and the unit dimension second, e.g. of size (batch size, number of units).
  • if_output_weight (bool): whether to use the output weights or input weights.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
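
A minimal sketch of this contribution-utility-style measure, input-weight variant (hypothetical sizes):

    import torch
    from torch import nn

    layer = nn.Linear(16, 8)
    activation = torch.randn(32, 8)  # hypothetical activations of this layer's 8 units

    weight_abs_sum = layer.weight.data.abs().sum(dim=1)  # (8,): sum of |incoming weights| per unit
    activation_abs_mean = activation.abs().mean(dim=0)   # (8,): mean |activation| over the batch

    importance_step = (weight_abs_sum * activation_abs_mean).detach()
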
def get_importance_step_layer_gradient_x_activation_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
889    def get_importance_step_layer_gradient_x_activation_abs(
890        self,
891        layer_name: str,
892        input: Tensor | tuple[Tensor, ...],
893        target: Tensor | None,
894        batch_idx: int,
895        num_batches: int,
896    ) -> Tensor:
897        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using [Layer Gradient X Activation](https://captum.ai/api/layer.html#layer-gradient-x-activation) in Captum.
898
899        **Args:**
900        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
901        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
902        - **target** (`Tensor` | `None`): the target batch of the training step.
903        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
904        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
905
906        **Returns:**
907        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
908        """
909        layer = self.backbone.get_layer_by_name(layer_name)
910
911        input = input.requires_grad_()
912
913        # initialize the Layer Gradient X Activation object
914        layer_gradient_x_activation = LayerGradientXActivation(
915            forward_func=self.forward, layer=layer
916        )
917
918        self.set_forward_func_return_logits_only(True)
919        # calculate layer attribution of the step
920        attribution = layer_gradient_x_activation.attribute(
921            inputs=input,
922            target=target,
923            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
924        )
925        self.set_forward_func_return_logits_only(False)
926
927        attribution_abs_batch_mean = torch.mean(
928            torch.abs(attribution),
929            dim=[
930                i for i in range(attribution.dim()) if i != 1
931            ],  # average the features over batch samples
932        )
933
934        importance_step_layer = attribution_abs_batch_mean
935        importance_step_layer = importance_step_layer.detach()
936
937        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of the gradient of layer activation multiplied by the activation. We implement this using Layer Gradient X Activation in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
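
The same Captum call pattern, shown as a standalone sketch on a hypothetical two-layer network instead of the HAT backbone (so no additional_forward_args are needed here):

    import torch
    from torch import nn
    from captum.attr import LayerGradientXActivation

    model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
    layer = model[0]  # attribute to the first linear layer's 8 units

    inputs = torch.randn(32, 16).requires_grad_()
    target = torch.randint(0, 4, (32,))

    lga = LayerGradientXActivation(forward_func=model, layer=layer)
    attribution = lga.attribute(inputs=inputs, target=target)  # shape (32, 8)

    # batch-mean of |attribution| -> one importance value per unit
    importance_step = attribution.abs().mean(dim=0).detach()
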
def get_importance_step_layer_weight_gradient_square_sum(self, layer_name: str, activation: torch.Tensor, if_output_weight: bool) -> torch.Tensor:
939    def get_importance_step_layer_weight_gradient_square_sum(
940        self,
941        layer_name: str,
942        activation: Tensor,
943        if_output_weight: bool,
944    ) -> Tensor:
945        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of squared layer weight gradients. The squared weight gradient corresponds to the Fisher information used in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
946
947        **Args:**
948        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
949        - **activation** (`Tensor`): the activation tensor of the layer, with the batch dimension first and the unit dimension second, e.g. of size (batch size, number of units).
950        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
951
952        **Returns:**
953        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
954        """
955        layer = self.backbone.get_layer_by_name(layer_name)
956
957        if not if_output_weight:
958            gradient_square = layer.weight.grad.data**2
959            gradient_square_sum = torch.sum(
960                gradient_square,
961                dim=[
962                    i for i in range(gradient_square.dim()) if i != 0
963                ],  # sum over the input dimension
964            )
965        else:
966            gradient_square = self.next_layer(layer_name).weight.grad.data**2
967            gradient_square_sum = torch.sum(
968                gradient_square,
969                dim=[
970                    i for i in range(gradient_square.dim()) if i != 1
971                ],  # sum over the output dimension
972            )
973
974        importance_step_layer = gradient_square_sum
975        importance_step_layer = importance_step_layer.detach()
976
977        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of squared layer weight gradients. The squared weight gradient corresponds to the Fisher information used in EWC.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • activation (Tensor): the activation tensor of the layer, with the batch dimension first and the unit dimension second, e.g. of size (batch size, number of units).
  • if_output_weight (bool): whether to use the output weights or input weights.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
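
A minimal sketch of the per-unit squared-gradient sum after a backward pass (hypothetical model and loss):

    import torch
    import torch.nn.functional as F
    from torch import nn

    layer = nn.Linear(16, 8)
    head = nn.Linear(8, 4)

    x = torch.randn(32, 16)
    y = torch.randint(0, 4, (32,))

    logits = head(torch.relu(layer(x)))
    F.cross_entropy(logits, y).backward()  # populates .grad on the weights

    # input-weight version: sum of squared gradients over each unit's incoming weights -> (8,)
    importance_step = (layer.weight.grad.data ** 2).sum(dim=1).detach()
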
def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(self, layer_name: str, activation: torch.Tensor, if_output_weight: bool) -> torch.Tensor:
 979    def get_importance_step_layer_weight_gradient_square_sum_x_activation_abs(
 980        self,
 981        layer_name: str,
 982        activation: Tensor,
 983        if_output_weight: bool,
 984    ) -> Tensor:
 985        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of squared layer weight gradients multiplied by the absolute values of activation. The squared weight gradient corresponds to the Fisher information used in [EWC](https://www.pnas.org/doi/10.1073/pnas.1611835114).
 986
 987        **Args:**
 988        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
 989        - **activation** (`Tensor`): the activation tensor of the layer, with the batch dimension first and the unit dimension second, e.g. of size (batch size, number of units).
 990        - **if_output_weight** (`bool`): whether to use the output weights or input weights.
 991
 992        **Returns:**
 993        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
 994        """
 995        layer = self.backbone.get_layer_by_name(layer_name)
 996
 997        if not if_output_weight:
 998            gradient_square = layer.weight.grad.data**2
 999            gradient_square_sum = torch.sum(
1000                gradient_square,
1001                dim=[
1002                    i for i in range(gradient_square.dim()) if i != 0
1003                ],  # sum over the input dimension
1004            )
1005        else:
1006            gradient_square = self.next_layer(layer_name).weight.grad.data**2
1007            gradient_square_sum = torch.sum(
1008                gradient_square,
1009                dim=[
1010                    i for i in range(gradient_square.dim()) if i != 1
1011                ],  # sum over the output dimension
1012            )
1013
1014        activation_abs_batch_mean = torch.mean(
1015            torch.abs(activation),
1016            dim=[
1017                i for i in range(activation.dim()) if i != 1
1018            ],  # average the features over batch samples
1019        )
1020
1021        importance_step_layer = gradient_square_sum * activation_abs_batch_mean
1022        importance_step_layer = importance_step_layer.detach()
1023
1024        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of squared layer weight gradients multiplied by the absolute values of activation. The squared weight gradient corresponds to the Fisher information used in EWC.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • activation (Tensor): the activation tensor of the layer, with the batch dimension first and the unit dimension second, e.g. of size (batch size, number of units).
  • if_output_weight (bool): whether to use the output weights or input weights.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_conductance_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1026    def get_importance_step_layer_conductance_abs(
1027        self,
1028        layer_name: str,
1029        input: Tensor | tuple[Tensor, ...],
1030        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1031        target: Tensor | None,
1032        batch_idx: int,
1033        num_batches: int,
1034    ) -> Tensor:
1035        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [conductance](https://openreview.net/forum?id=SylKoo0cKm). We implement this using [Layer Conductance](https://captum.ai/api/layer.html#layer-conductance) in Captum.
1036
1037        **Args:**
1038        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1039        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1040        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerConductance.attribute) for more details.
1041        - **target** (`Tensor` | `None`): the target batch of the training step.
1042        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1043        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1044
1045        **Returns:**
1046        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1047        """
1048        layer = self.backbone.get_layer_by_name(layer_name)
1049
1050        # initialize the Layer Conductance object
1051        layer_conductance = LayerConductance(forward_func=self.forward, layer=layer)
1052
1053        self.set_forward_func_return_logits_only(True)
1054        # calculate layer attribution of the step
1055        attribution = layer_conductance.attribute(
1056            inputs=input,
1057            baselines=baselines,
1058            target=target,
1059            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1060        )
1061        self.set_forward_func_return_logits_only(False)
1062
1063        attribution_abs_batch_mean = torch.mean(
1064            torch.abs(attribution),
1065            dim=[
1066                i for i in range(attribution.dim()) if i != 1
1067            ],  # average the features over batch samples
1068        )
1069
1070        importance_step_layer = attribution_abs_batch_mean
1071        importance_step_layer = importance_step_layer.detach()
1072
1073        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of conductance. We implement this using Layer Conductance in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed in this method. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
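
The conductance variant follows the same Captum pattern but additionally takes baselines and an integration step count; a standalone sketch on a hypothetical network:

    import torch
    from torch import nn
    from captum.attr import LayerConductance

    model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
    layer = model[0]

    inputs = torch.randn(32, 16)
    target = torch.randint(0, 4, (32,))

    layer_conductance = LayerConductance(forward_func=model, layer=layer)
    attribution = layer_conductance.attribute(
        inputs=inputs,
        baselines=None,  # None defaults to all-zero baselines
        target=target,
        n_steps=10,      # fewer integration steps is faster but less accurate
    )

    importance_step = attribution.abs().mean(dim=0).detach()  # (8,)
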
def get_importance_step_layer_internal_influence_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1075    def get_importance_step_layer_internal_influence_abs(
1076        self,
1077        layer_name: str,
1078        input: Tensor | tuple[Tensor, ...],
1079        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1080        target: Tensor | None,
1081        batch_idx: int,
1082        num_batches: int,
1083    ) -> Tensor:
1084        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [internal influence](https://openreview.net/forum?id=SJPpHzW0-). We implement this using [Internal Influence](https://captum.ai/api/layer.html#internal-influence) in Captum.
1085
1086        **Args:**
1087        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1088        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1089        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed in this method. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.InternalInfluence.attribute) for more details.
1090        - **target** (`Tensor` | `None`): the target batch of the training step.
1091        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1092        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1093
1094        **Returns:**
1095        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1096        """
1097        layer = self.backbone.get_layer_by_name(layer_name)
1098
1099        # initialize the Internal Influence object
1100        internal_influence = InternalInfluence(forward_func=self.forward, layer=layer)
1101
1102        # convert the target to long type to avoid error
1103        target = target.long() if target is not None else None
1104
1105        self.set_forward_func_return_logits_only(True)
1106        # calculate layer attribution of the step
1107        attribution = internal_influence.attribute(
1108            inputs=input,
1109            baselines=baselines,
1110            target=target,
1111            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1112            n_steps=5,  # set 5 instead of the default 50 to accelerate the computation
1113        )
1114        self.set_forward_func_return_logits_only(False)
1115
1116        attribution_abs_batch_mean = torch.mean(
1117            torch.abs(attribution),
1118            dim=[
1119                i for i in range(attribution.dim()) if i != 1
1120            ],  # average the features over batch samples
1121        )
1122
1123        importance_step_layer = attribution_abs_batch_mean
1124        importance_step_layer = importance_step_layer.detach()
1125
1126        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of internal influence. We implement this using Internal Influence in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed in this method. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_gradcam_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1128    def get_importance_step_layer_gradcam_abs(
1129        self,
1130        layer_name: str,
1131        input: Tensor | tuple[Tensor, ...],
1132        target: Tensor | None,
1133        batch_idx: int,
1134        num_batches: int,
1135    ) -> Tensor:
1136        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [Grad-CAM](https://arxiv.org/abs/1610.02391). We implement this using [Layer Grad-CAM](https://captum.ai/api/layer.html#gradcam) in Captum.
1137
1138        **Args:**
1139        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1140        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1141        - **target** (`Tensor` | `None`): the target batch of the training step.
1142        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1143        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1144
1145        **Returns:**
1146        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1147        """
1148        layer = self.backbone.get_layer_by_name(layer_name)
1149
1150        # initialize the GradCAM object
1151        gradcam = LayerGradCam(forward_func=self.forward, layer=layer)
1152
1153        self.set_forward_func_return_logits_only(True)
1154        # calculate layer attribution of the step
1155        attribution = gradcam.attribute(
1156            inputs=input,
1157            target=target,
1158            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1159        )
1160        self.set_forward_func_return_logits_only(False)
1161
1162        attribution_abs_batch_mean = torch.mean(
1163            torch.abs(attribution),
1164            dim=[
1165                i for i in range(attribution.dim()) if i != 1
1166            ],  # average the features over batch samples
1167        )
1168
1169        importance_step_layer = attribution_abs_batch_mean
1170        importance_step_layer = importance_step_layer.detach()
1171
1172        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of Grad-CAM. We implement this using Layer Grad-CAM in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_deeplift_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1174    def get_importance_step_layer_deeplift_abs(
1175        self,
1176        layer_name: str,
1177        input: Tensor | tuple[Tensor, ...],
1178        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1179        target: Tensor | None,
1180        batch_idx: int,
1181        num_batches: int,
1182    ) -> Tensor:
1183        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift](https://proceedings.mlr.press/v70/shrikumar17a/shrikumar17a.pdf). We implement this using [Layer DeepLift](https://captum.ai/api/layer.html#layer-deeplift) in Captum.
1184
1185        **Args:**
1186        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1187        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1188        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLift.attribute) for more details.
1189        - **target** (`Tensor` | `None`): the target batch of the training step.
1190        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1191        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1192
1193        **Returns:**
1194        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1195        """
1196        layer = self.backbone.get_layer_by_name(layer_name)
1197
1198        # initialize the Layer DeepLift object
1199        layer_deeplift = LayerDeepLift(model=self, layer=layer)
1200
1201        # convert the target to long type to avoid error
1202        target = target.long() if target is not None else None
1203
1204        self.set_forward_func_return_logits_only(True)
1205        # calculate layer attribution of the step
1206        attribution = layer_deeplift.attribute(
1207            inputs=input,
1208            baselines=baselines,
1209            target=target,
1210            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1211        )
1212        self.set_forward_func_return_logits_only(False)
1213
1214        attribution_abs_batch_mean = torch.mean(
1215            torch.abs(attribution),
1216            dim=[
1217                i for i in range(attribution.dim()) if i != 1
1218            ],  # average the features over batch samples
1219        )
1220
1221        importance_step_layer = attribution_abs_batch_mean
1222        importance_step_layer = importance_step_layer.detach()
1223
1224        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of DeepLift. We implement this using Layer DeepLift in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): baselines define reference samples that are compared with the inputs. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_deepliftshap_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1226    def get_importance_step_layer_deepliftshap_abs(
1227        self,
1228        layer_name: str,
1229        input: Tensor | tuple[Tensor, ...],
1230        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1231        target: Tensor | None,
1232        batch_idx: int,
1233        num_batches: int,
1234    ) -> Tensor:
1235        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [DeepLift SHAP](https://proceedings.neurips.cc/paper_files/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf). We implement this using [Layer DeepLiftShap](https://captum.ai/api/layer.html#layer-deepliftshap) in Captum.
1236
1237        **Args:**
1238        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1239        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1240        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): baselines define reference samples that are compared with the inputs. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerDeepLiftShap.attribute) for more details.
1241        - **target** (`Tensor` | `None`): the target batch of the training step.
1242        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1243        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1244
1245        **Returns:**
1246        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1247        """
1248        layer = self.backbone.get_layer_by_name(layer_name)
1249
1250        # initialize the Layer DeepLiftShap object
1251        layer_deepliftshap = LayerDeepLiftShap(model=self, layer=layer)
1252
1253        # convert the target to long type to avoid error
1254        target = target.long() if target is not None else None
1255
1256        self.set_forward_func_return_logits_only(True)
1257        # calculate layer attribution of the step
1258        attribution = layer_deepliftshap.attribute(
1259            inputs=input,
1260            baselines=baselines,
1261            target=target,
1262            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1263        )
1264        self.set_forward_func_return_logits_only(False)
1265
1266        attribution_abs_batch_mean = torch.mean(
1267            torch.abs(attribution),
1268            dim=[
1269                i for i in range(attribution.dim()) if i != 1
1270            ],  # average the features over batch samples
1271        )
1272
1273        importance_step_layer = attribution_abs_batch_mean
1274        importance_step_layer = importance_step_layer.detach()
1275
1276        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of DeepLift SHAP. We implement this using Layer DeepLiftShap in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): baselines define reference samples that are compared with the inputs. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
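
Unlike the single-baseline methods, DeepLift SHAP expects a distribution of baseline samples; a standalone sketch on a hypothetical network:

    import torch
    from torch import nn
    from captum.attr import LayerDeepLiftShap

    model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
    layer = model[0]

    inputs = torch.randn(32, 16)
    target = torch.randint(0, 4, (32,))
    baselines = torch.zeros(10, 16)  # a (small) distribution of reference samples

    layer_deepliftshap = LayerDeepLiftShap(model, layer)
    attribution = layer_deepliftshap.attribute(inputs=inputs, baselines=baselines, target=target)

    importance_step = attribution.abs().mean(dim=0).detach()  # (8,)
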
def get_importance_step_layer_gradientshap_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1278    def get_importance_step_layer_gradientshap_abs(
1279        self,
1280        layer_name: str,
1281        input: Tensor | tuple[Tensor, ...],
1282        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1283        target: Tensor | None,
1284        batch_idx: int,
1285        num_batches: int,
1286    ) -> Tensor:
1287        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using [Layer GradientShap](https://captum.ai/api/layer.html#layer-gradientshap) in Captum.
1288
1289        **Args:**
1290        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1291        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1292        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which expectation is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerGradientShap.attribute) for more details. If `None`, the baselines are set to zero.
1293        - **target** (`Tensor` | `None`): the target batch of the training step.
1294        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1295        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1296
1297        **Returns:**
1298        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1299        """
1300        layer = self.backbone.get_layer_by_name(layer_name)
1301
1302        if baselines is None:
1303            baselines = torch.zeros_like(
1304                input
1305            )  # baselines are mandatory for GradientShap API. We explicitly set them to zero
1306
1307        # initialize the Layer GradientShap object
1308        layer_gradientshap = LayerGradientShap(forward_func=self.forward, layer=layer)
1309
1310        # convert the target to long type to avoid error
1311        target = target.long() if target is not None else None
1312
1313        self.set_forward_func_return_logits_only(True)
1314        # calculate layer attribution of the step
1315        attribution = layer_gradientshap.attribute(
1316            inputs=input,
1317            baselines=baselines,
1318            target=target,
1319            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1320        )
1321        self.set_forward_func_return_logits_only(False)
1322
1323        attribution_abs_batch_mean = torch.mean(
1324            torch.abs(attribution),
1325            dim=[
1326                i for i in range(attribution.dim()) if i != 1
1327            ],  # average the features over batch samples
1328        )
1329
1330        importance_step_layer = attribution_abs_batch_mean
1331        importance_step_layer = importance_step_layer.detach()
1332
1333        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of gradient SHAP. We implement this using Layer GradientShap in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which expectation is computed. Please refer to the Captum documentation for more details. If None, the baselines are set to zero.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_integrated_gradients_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
1335    def get_importance_step_layer_integrated_gradients_abs(
1336        self,
1337        layer_name: str,
1338        input: Tensor | tuple[Tensor, ...],
1339        baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1340        target: Tensor | None,
1341        batch_idx: int,
1342        num_batches: int,
1343    ) -> Tensor:
1344        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [integrated gradients](https://proceedings.mlr.press/v70/sundararajan17a/sundararajan17a.pdf). We implement this using [Layer Integrated Gradients](https://captum.ai/api/layer.html#layer-integrated-gradients) in Captum.
1345
1346        **Args:**
1347        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1348        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1349        - **baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): starting point from which integral is computed. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients.attribute) for more details.
1350        - **target** (`Tensor` | `None`): the target batch of the training step.
1351        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1352        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1353
1354        **Returns:**
1355        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1356        """
1357        layer = self.backbone.get_layer_by_name(layer_name)
1358
1359        # initialize the Layer Integrated Gradients object
1360        layer_integrated_gradients = LayerIntegratedGradients(
1361            forward_func=self.forward, layer=layer
1362        )
1363
1364        self.set_forward_func_return_logits_only(True)
1365        # calculate layer attribution of the step
1366        attribution = layer_integrated_gradients.attribute(
1367            inputs=input,
1368            baselines=baselines,
1369            target=target,
1370            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1371        )
1372        self.set_forward_func_return_logits_only(False)
1373
1374        attribution_abs_batch_mean = torch.mean(
1375            torch.abs(attribution),
1376            dim=[
1377                i for i in range(attribution.dim()) if i != 1
1378            ],  # average the features over batch samples
1379        )
1380
1381        importance_step_layer = attribution_abs_batch_mean
1382        importance_step_layer = importance_step_layer.detach()
1383
1384        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of integrated gradients. We implement this using Layer Integrated Gradients in Captum.

Args:

  • layer_name (str): the name of layer to get neuron-wise importance.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): starting point from which integral is computed. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
def get_importance_step_layer_feature_ablation_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], layer_baselines: None | int | float | torch.Tensor | tuple[int | float | torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int, if_captum: bool = False) -> torch.Tensor:
1386    def get_importance_step_layer_feature_ablation_abs(
1387        self,
1388        layer_name: str,
1389        input: Tensor | tuple[Tensor, ...],
1390        layer_baselines: None | int | float | Tensor | tuple[int | float | Tensor, ...],
1391        target: Tensor | None,
1392        batch_idx: int,
1393        num_batches: int,
1394        if_captum: bool = False,
1395    ) -> Tensor:
1396        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [feature ablation](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) attribution. We implement this using [Layer Feature Ablation](https://captum.ai/api/layer.html#layer-feature-ablation) in Captum.
1397
1398        **Args:**
1399        - **layer_name** (`str`): the name of layer to get neuron-wise importance.
1400        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
1401        - **layer_baselines** (`None` | `int` | `float` | `Tensor` | `tuple[int | float | Tensor, ...]`): reference values which replace each layer input / output value when ablated. Please refer to the [Captum documentation](https://captum.ai/api/layer.html#captum.attr.LayerFeatureAblation.attribute) for more details.
1402        - **target** (`Tensor` | `None`): the target batch of the training step.
1403        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
1404        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.
1405        - **if_captum** (`bool`): whether to use Captum or not. If `True`, we use Captum to calculate the feature ablation. If `False`, we use our implementation. Default is `False`, because our implementation is much faster.
1406
1407        **Returns:**
1408        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
1409        """
1410        layer = self.backbone.get_layer_by_name(layer_name)
1411
1412        if not if_captum:
1413            # 1. Baseline logits (take first element of forward output)
1414            baseline_out, _, _ = self.forward(
1415                input, "train", batch_idx, num_batches, self.task_id
1416            )
1417            if target is not None:
1418                baseline_scores = baseline_out.gather(1, target.view(-1, 1)).squeeze(1)
1419            else:
1420                baseline_scores = baseline_out.sum(dim=1)
1421
1422            # 2. Capture layer’s output shape
1423            activs = {}
1424            handle = layer.register_forward_hook(
1425                lambda module, inp, out: activs.setdefault("output", out.detach())
1426            )
1427            _, _, _ = self.forward(input, "train", batch_idx, num_batches, self.task_id)
1428            handle.remove()
1429            layer_output = activs["output"]  # shape (B, F, ...)
1430
1431            # 3. Build baseline tensor matching that shape
1432            if layer_baselines is None:
1433                baseline_tensor = torch.zeros_like(layer_output)
1434            elif isinstance(layer_baselines, (int, float)):
1435                baseline_tensor = torch.full_like(layer_output, layer_baselines)
1436            elif isinstance(layer_baselines, Tensor):
1437                if layer_baselines.shape == layer_output.shape:
1438                    baseline_tensor = layer_baselines
1439                elif layer_baselines.shape == layer_output.shape[1:]:
1440                    baseline_tensor = layer_baselines.unsqueeze(0).repeat(
1441                        layer_output.size(0), *([1] * layer_baselines.ndim)
1442                    )
1443                else:
1444                    raise ValueError(...)
1445            else:
1446                raise ValueError(...)
1447
1448            B, F = layer_output.size(0), layer_output.size(1)
1449
1450            # 4. Create a “mega-batch” replicating the input F times
1451            if isinstance(input, tuple):
1452                mega_inputs = tuple(
1453                    t.unsqueeze(0).repeat(F, *([1] * t.ndim)).view(-1, *t.shape[1:])
1454                    for t in input
1455                )
1456            else:
1457                mega_inputs = (
1458                    input.unsqueeze(0)
1459                    .repeat(F, *([1] * input.ndim))
1460                    .view(-1, *input.shape[1:])
1461                )
1462
1463            # 5. Equally replicate the baseline tensor
1464            mega_baseline = (
1465                baseline_tensor.unsqueeze(0)
1466                .repeat(F, *([1] * baseline_tensor.ndim))
1467                .view(-1, *baseline_tensor.shape[1:])
1468            )
1469
1470            # 6. Precompute vectorized indices
1471            device = layer_output.device
1472            positions = torch.arange(F * B, device=device)  # [0,1,...,F*B-1]
1473            feat_idx = torch.arange(F, device=device).repeat_interleave(
1474                B
1475            )  # [0,0,...,1,1,...,F-1]
1476
1477            # 7. One hook to zero out each channel slice across the mega-batch
1478            def mega_ablate_hook(module, inp, out):
1479                out_mod = out.clone()
1480                # for each sample in mega-batch, zero its corresponding channel
1481                out_mod[positions, feat_idx] = mega_baseline[positions, feat_idx]
1482                return out_mod
1483
1484            h = layer.register_forward_hook(mega_ablate_hook)
1485            out_all, _, _ = self.forward(
1486                mega_inputs, "train", batch_idx, num_batches, self.task_id
1487            )
1488            h.remove()
1489
1490            # 8. Recover scores, reshape [F*B] → [F, B], diff & mean
1491            if target is not None:
1492                tgt_flat = target.unsqueeze(0).repeat(F, 1).view(-1)
1493                scores_all = out_all.gather(1, tgt_flat.view(-1, 1)).squeeze(1)
1494            else:
1495                scores_all = out_all.sum(dim=1)
1496
1497            scores_all = scores_all.view(F, B)
1498            diffs = torch.abs(baseline_scores.unsqueeze(0) - scores_all)
1499            importance_step_layer = diffs.mean(dim=1).detach()  # [F]
1500
1501            return importance_step_layer
1502
1503        else:
1504            # initialize the Layer Feature Ablation object
1505            layer_feature_ablation = LayerFeatureAblation(
1506                forward_func=self.forward, layer=layer
1507            )
1508
1509            # calculate layer attribution of the step
1510            self.set_forward_func_return_logits_only(True)
1511            attribution = layer_feature_ablation.attribute(
1512                inputs=input,
1513                layer_baselines=layer_baselines,
1514                # target=target, # disable target to enable perturbations_per_eval
1515                additional_forward_args=("train", batch_idx, num_batches, self.task_id),
1516                perturbations_per_eval=128,  # to accelerate the computation
1517            )
1518            self.set_forward_func_return_logits_only(False)
1519
1520            attribution_abs_batch_mean = torch.mean(
1521                torch.abs(attribution),
1522                dim=[
1523                    i for i in range(attribution.dim()) if i != 1
1524                ],  # average the features over batch samples
1525            )
1526
1527        importance_step_layer = attribution_abs_batch_mean
1528        importance_step_layer = importance_step_layer.detach()
1529
1530        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of feature ablation attribution. We implement this using Layer Feature Ablation in Captum. A minimal standalone sketch of per-unit feature ablation is given below.

Args:

  • layer_name (str): the name of the layer to get neuron-wise importance from.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • layer_baselines (None | int | float | Tensor | tuple[int | float | Tensor, ...]): reference values which replace each layer input / output value when ablated. Please refer to the Captum documentation for more details.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.
  • if_captum (bool): whether to use Captum or not. If True, we use Captum to calculate the feature ablation. If False, we use our implementation. Default is False, because our implementation is much faster.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
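
For intuition about the non-Captum branch of this method, here is a minimal standalone sketch (illustration only) of per-unit feature ablation written as an explicit loop: each hidden unit's activation is replaced by a zero baseline and the mean absolute change of the target logit is recorded. The toy model and shapes are assumptions; the class's own implementation obtains the same quantity in one forward pass by replicating the batch once per unit (the "mega-batch" trick).

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4)
)
layer = model[0]                  # the layer whose 16 units are ablated one by one
x = torch.randn(32, 8)
y = torch.randint(0, 4, (32,))

with torch.no_grad():
    # unablated target logits
    base_scores = model(x).gather(1, y.view(-1, 1)).squeeze(1)

    importance = torch.zeros(16)
    for unit in range(16):
        def ablate(module, inp, out, unit=unit):
            # replace one unit's activation by the (zero) baseline
            out = out.clone()
            out[:, unit] = 0.0
            return out

        handle = layer.register_forward_hook(ablate)
        scores = model(x).gather(1, y.view(-1, 1)).squeeze(1)
        handle.remove()

        # importance of a unit = mean absolute change of the target logit
        importance[unit] = (base_scores - scores).abs().mean()
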
def get_importance_step_layer_lrp_abs(self, layer_name: str, input: torch.Tensor | tuple[torch.Tensor, ...], target: torch.Tensor | None, batch_idx: int, num_batches: int) -> torch.Tensor:
    def get_importance_step_layer_lrp_abs(
        self,
        layer_name: str,
        input: Tensor | tuple[Tensor, ...],
        target: Tensor | None,
        batch_idx: int,
        num_batches: int,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140). We implement this using [Layer LRP](https://captum.ai/api/layer.html#layer-lrp) in Captum.

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance from.
        - **input** (`Tensor` | `tuple[Tensor, ...]`): the input batch of the training step.
        - **target** (`Tensor` | `None`): the target batch of the training step.
        - **batch_idx** (`int`): the index of the current batch. This is an argument of the forward function during training.
        - **num_batches** (`int`): the number of batches in the training step. This is an argument of the forward function during training.

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        # initialize the Layer LRP object
        layer_lrp = LayerLRP(model=self, layer=layer)

        # set the model to evaluation mode to prevent updating the model parameters
        self.eval()

        self.set_forward_func_return_logits_only(True)
        # calculate layer attribution of the step
        attribution = layer_lrp.attribute(
            inputs=input,
            target=target,
            additional_forward_args=("train", batch_idx, num_batches, self.task_id),
        )
        self.set_forward_func_return_logits_only(False)

        attribution_abs_batch_mean = torch.mean(
            torch.abs(attribution),
            dim=[
                i for i in range(attribution.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = attribution_abs_batch_mean
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the absolute values of LRP. We implement this using Layer LRP in Captum. A minimal standalone sketch is given below.

Args:

  • layer_name (str): the name of the layer to get neuron-wise importance from.
  • input (Tensor | tuple[Tensor, ...]): the input batch of the training step.
  • target (Tensor | None): the target batch of the training step.
  • batch_idx (int): the index of the current batch. This is an argument of the forward function during training.
  • num_batches (int): the number of batches in the training step. This is an argument of the forward function during training.

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
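
As a minimal standalone sketch of the same idea (illustration only; the toy model and shapes are assumptions): Captum's LayerLRP attributes a toy MLP's target logits to one layer's units, and the batch mean of the absolute relevance gives one raw importance value per unit.

import torch
from captum.attr import LayerLRP

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4)
)
layer = model[0]                  # the layer whose 16 units are scored
x = torch.randn(32, 8)
y = torch.randint(0, 4, (32,))

model.eval()  # attribute in evaluation mode, as the method above does
lrp = LayerLRP(model, layer)
relevance = lrp.attribute(inputs=x, target=y)  # (32, 16)

# batch mean of absolute relevance: one raw importance value per unit
importance = relevance.abs().mean(dim=0)  # (16,)
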
def get_importance_step_layer_cbp_adaptive_contribution(self, layer_name: str, activation: torch.Tensor) -> torch.Tensor:
    def get_importance_step_layer_cbp_adaptive_contribution(
        self,
        layer_name: str,
        activation: Tensor,
    ) -> Tensor:
        r"""Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the layer's output weights, multiplied by the batch-mean absolute activation and divided by the sum of absolute values of the layer's input weights. It is equal to the adaptive contribution utility in [CBP](https://www.nature.com/articles/s41586-024-07711-7).

        **Args:**
        - **layer_name** (`str`): the name of the layer to get neuron-wise importance from.
        - **activation** (`Tensor`): the activation tensor of the layer for the training batch, with the unit dimension at index 1 (e.g. size (batch size, number of units)).

        **Returns:**
        - **importance_step_layer** (`Tensor`): the neuron-wise importance of the layer of the training step.
        """
        layer = self.backbone.get_layer_by_name(layer_name)

        input_weight_abs = torch.abs(layer.weight.data)
        input_weight_abs_sum = torch.sum(
            input_weight_abs,
            dim=[
                i for i in range(input_weight_abs.dim()) if i != 0
            ],  # sum over the input dimensions, keeping one value per unit
        )
        input_weight_abs_sum_reciprocal = torch.reciprocal(input_weight_abs_sum)

        output_weight_abs = torch.abs(self.next_layer(layer_name).weight.data)
        output_weight_abs_sum = torch.sum(
            output_weight_abs,
            dim=[
                i for i in range(output_weight_abs.dim()) if i != 1
            ],  # sum over the next layer's output dimensions, keeping one value per unit
        )

        activation_abs_batch_mean = torch.mean(
            torch.abs(activation),
            dim=[
                i for i in range(activation.dim()) if i != 1
            ],  # average the features over batch samples
        )

        importance_step_layer = (
            output_weight_abs_sum
            * activation_abs_batch_mean
            * input_weight_abs_sum_reciprocal
        )
        importance_step_layer = importance_step_layer.detach()

        return importance_step_layer

Get the raw neuron-wise importance (before scaling) of a layer of a training step. See $I^{\tau}_l(\mathbf{x},y)$ (before Eqs. (5) and (6)) in the paper. This method uses the sum of absolute values of the layer's output weights, multiplied by the batch-mean absolute activation and divided by the sum of absolute values of the layer's input weights. It is equal to the adaptive contribution utility in CBP. A minimal standalone sketch of this computation is given below.

Args:

  • layer_name (str): the name of the layer to get neuron-wise importance from.
  • activation (Tensor): the activation tensor of the layer for the training batch, with the unit dimension at index 1 (e.g. size (batch size, number of units)).

Returns:

  • importance_step_layer (Tensor): the neuron-wise importance of the layer of the training step.
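
A minimal standalone sketch of this utility (illustration only; the toy layers and shapes are assumptions): for each hidden unit, multiply the summed absolute outgoing weights by the batch-mean absolute activation and divide by the summed absolute incoming weights.

import torch

fc1 = torch.nn.Linear(8, 16)   # the layer whose 16 units are scored
fc2 = torch.nn.Linear(16, 4)   # the next layer, holding the outgoing weights
h = torch.relu(fc1(torch.randn(32, 8)))  # hidden activation, shape (32, 16)

in_w_sum = fc1.weight.data.abs().sum(dim=1)   # summed |incoming weights| per unit, (16,)
out_w_sum = fc2.weight.data.abs().sum(dim=0)  # summed |outgoing weights| per unit, (16,)
act_mean = h.abs().mean(dim=0)                # batch-mean |activation| per unit, (16,)

# adaptive contribution utility: |outgoing| * |activation| / |incoming|
importance = out_w_sum * act_mean / in_w_sum  # (16,)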