clarena.utils.metrics

The submodule in utils for metric utilities and custom torchmetrics.

  1r"""
  2The submodule in `utils` for metric utilities and custom torchmetrics.
  3"""
  4
  5__all__ = ["linear_cka", "MeanMetricBatch", "HATNetworkCapacityMetric"]
  6
  7
  8import logging
  9from typing import Any
 10
 11import torch
 12from torch import Tensor
 13from torchmetrics.aggregation import BaseAggregator
 14
 15# always get logger for built-in logging in each module
 16pylogger = logging.getLogger(__name__)
 17
 18
 19def linear_cka(x: Tensor, y: Tensor, eps: float = 1e-12) -> Tensor:
 20    r"""Compute linear CKA similarity between two representation matrices.
 21
 22    **Args:**
 23    - **x** (`Tensor`): the first representation matrix with shape `(num_samples, num_features_x)`.
 24    - **y** (`Tensor`): the second representation matrix with shape `(num_samples, num_features_y)`.
 25    - **eps** (`float`): a small constant to avoid division by zero.
 26
 27    **Returns:**
 28    - **cka_similarity** (`Tensor`): a scalar tensor of linear CKA similarity.
 29    """
 30    if x.dim() != 2:
 31        raise ValueError(
 32            f"Expected `x` to be a 2D tensor with shape (num_samples, num_features_x), but got {x.dim()} dimensions."
 33        )
 34    if y.dim() != 2:
 35        raise ValueError(
 36            f"Expected `y` to be a 2D tensor with shape (num_samples, num_features_y), but got {y.dim()} dimensions."
 37        )
 38    if x.size(0) != y.size(0):
 39        raise ValueError(
 40            f"Expected `x` and `y` to have the same number of samples, but got {x.size(0)} and {y.size(0)}."
 41        )
 42
 43    # Center features across samples before computing similarities.
 44    x = x - x.mean(dim=0, keepdim=True)
 45    y = y - y.mean(dim=0, keepdim=True)
 46
 47    cross_cov_norm = torch.norm(x.T @ y, p="fro") ** 2
 48    normalization = torch.norm(x.T @ x, p="fro") * torch.norm(y.T @ y, p="fro")
 49
 50    return cross_cov_norm / (normalization + eps)
 51
 52
 53class MeanMetricBatch(BaseAggregator):
 54    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.
 55
 56    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
 57    """
 58
 59    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
 60        r"""Initialize the metric. Add state variables."""
 61        BaseAggregator.__init__(
 62            self,
 63            "sum",
 64            torch.tensor(0.0, dtype=torch.get_default_dtype()),
 65            nan_strategy,
 66            state_name="sum",
 67            **kwargs,
 68        )
 69
 70        self.sum: Tensor
 71        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""
 72
 73        self.add_state(
 74            "num",
 75            default=torch.tensor(0, dtype=torch.get_default_dtype()),
 76            dist_reduce_fx="sum",
 77        )
 78        self.num: Tensor
 79        r"""State variable created by `add_state()` to store the number of the data till this batch."""
 80
 81    def update(self, value: torch.Tensor, batch_size: int) -> None:
 82        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.
 83
 84        **Args:**
 85        - **val** (`torch.Tensor`): the metric value of the batch to update the sum.
 86        - **batch_size** (`int`): the value to update the num, which is the batch size.
 87        """
 88
 89        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
 90        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)
 91
 92        self.sum += value * batch_size
 93        self.num += batch_size
 94
 95    def compute(self) -> Tensor:
 96        r"""Compute this mean metric value till this batch.
 97
 98        **Returns:**
 99        - **mean** (`Tensor`): the calculated mean result.
100        """
101        return self.sum / self.num
102
103
class HATNetworkCapacityMetric(BaseAggregator):
    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

    Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialise the HAT network capacity metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): strategy for handling NaN values, forwarded to `BaseAggregator`.
        - **kwargs** (`Any`): additional keyword arguments forwarded to `BaseAggregator`.
        """
        super().__init__(
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum_adjustment_rate",
            **kwargs,
        )
        self.sum_adjustment_rate: Tensor
        r"""State variable created by `super().__init__()` to store the sum of the adjustment rate values till this layer."""

        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
        self.num_params: Tensor
        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""

    def update(
        self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
    ) -> None:
        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.

        **Args:**
        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
        - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
        """
        adjustment_rate_weight_layer = torch.as_tensor(
            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
        )
        adjustment_rate_bias_layer = torch.as_tensor(
            adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
        )

        # Accumulate total adjustment mass and parameter count; their ratio is the capacity.
        self.sum_adjustment_rate += (
            adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum()
        )
        self.num_params += (
            adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel()
        )

    def compute(self) -> Tensor:
        r"""Compute this HAT network capacity till this layer.

        **Returns:**
        - **network_capacity** (`Tensor`): the calculated network capacity result, i.e. the average adjustment rate per parameter. NaN if `update()` has never been called (0/0).
        """
        return self.sum_adjustment_rate / self.num_params
def linear_cka(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
20def linear_cka(x: Tensor, y: Tensor, eps: float = 1e-12) -> Tensor:
21    r"""Compute linear CKA similarity between two representation matrices.
22
23    **Args:**
24    - **x** (`Tensor`): the first representation matrix with shape `(num_samples, num_features_x)`.
25    - **y** (`Tensor`): the second representation matrix with shape `(num_samples, num_features_y)`.
26    - **eps** (`float`): a small constant to avoid division by zero.
27
28    **Returns:**
29    - **cka_similarity** (`Tensor`): a scalar tensor of linear CKA similarity.
30    """
31    if x.dim() != 2:
32        raise ValueError(
33            f"Expected `x` to be a 2D tensor with shape (num_samples, num_features_x), but got {x.dim()} dimensions."
34        )
35    if y.dim() != 2:
36        raise ValueError(
37            f"Expected `y` to be a 2D tensor with shape (num_samples, num_features_y), but got {y.dim()} dimensions."
38        )
39    if x.size(0) != y.size(0):
40        raise ValueError(
41            f"Expected `x` and `y` to have the same number of samples, but got {x.size(0)} and {y.size(0)}."
42        )
43
44    # Center features across samples before computing similarities.
45    x = x - x.mean(dim=0, keepdim=True)
46    y = y - y.mean(dim=0, keepdim=True)
47
48    cross_cov_norm = torch.norm(x.T @ y, p="fro") ** 2
49    normalization = torch.norm(x.T @ x, p="fro") * torch.norm(y.T @ y, p="fro")
50
51    return cross_cov_norm / (normalization + eps)

Compute linear CKA similarity between two representation matrices.

Args:

  • x (Tensor): the first representation matrix with shape (num_samples, num_features_x).
  • y (Tensor): the second representation matrix with shape (num_samples, num_features_y).
  • eps (float): a small constant to avoid division by zero.

Returns:

  • cka_similarity (Tensor): a scalar tensor of linear CKA similarity.
class MeanMetricBatch(torchmetrics.aggregation.BaseAggregator):
 54class MeanMetricBatch(BaseAggregator):
 55    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.
 56
 57    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
 58    """
 59
 60    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
 61        r"""Initialize the metric. Add state variables."""
 62        BaseAggregator.__init__(
 63            self,
 64            "sum",
 65            torch.tensor(0.0, dtype=torch.get_default_dtype()),
 66            nan_strategy,
 67            state_name="sum",
 68            **kwargs,
 69        )
 70
 71        self.sum: Tensor
 72        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""
 73
 74        self.add_state(
 75            "num",
 76            default=torch.tensor(0, dtype=torch.get_default_dtype()),
 77            dist_reduce_fx="sum",
 78        )
 79        self.num: Tensor
 80        r"""State variable created by `add_state()` to store the number of the data till this batch."""
 81
 82    def update(self, value: torch.Tensor, batch_size: int) -> None:
 83        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.
 84
 85        **Args:**
 86        - **val** (`torch.Tensor`): the metric value of the batch to update the sum.
 87        - **batch_size** (`int`): the value to update the num, which is the batch size.
 88        """
 89
 90        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
 91        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)
 92
 93        self.sum += value * batch_size
 94        self.num += batch_size
 95
 96    def compute(self) -> Tensor:
 97        r"""Compute this mean metric value till this batch.
 98
 99        **Returns:**
100        - **mean** (`Tensor`): the calculated mean result.
101        """
102        return self.sum / self.num

A TorchMetrics metric to calculate the mean of metrics across data batches.

This is used for accumulated metrics in deep learning. See here for more details.

MeanMetricBatch(nan_strategy: str | float = 'error', **kwargs: Any)
60    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
61        r"""Initialize the metric. Add state variables."""
62        BaseAggregator.__init__(
63            self,
64            "sum",
65            torch.tensor(0.0, dtype=torch.get_default_dtype()),
66            nan_strategy,
67            state_name="sum",
68            **kwargs,
69        )
70
71        self.sum: Tensor
72        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""
73
74        self.add_state(
75            "num",
76            default=torch.tensor(0, dtype=torch.get_default_dtype()),
77            dist_reduce_fx="sum",
78        )
79        self.num: Tensor
80        r"""State variable created by `add_state()` to store the number of the data till this batch."""

Initialize the metric. Add state variables.

sum: torch.Tensor

State variable created by super().__init__() to store the sum of the metric values till this batch.

num: torch.Tensor

State variable created by add_state() to store the number of the data till this batch.

def update(self, value: torch.Tensor, batch_size: int) -> None:
82    def update(self, value: torch.Tensor, batch_size: int) -> None:
83        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.
84
85        **Args:**
86        - **val** (`torch.Tensor`): the metric value of the batch to update the sum.
87        - **batch_size** (`int`): the value to update the num, which is the batch size.
88        """
89
90        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
91        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)
92
93        self.sum += value * batch_size
94        self.num += batch_size

Update and accumulate the sum of metric value and num of the data till this batch from the batch.

Args:

  • value (torch.Tensor): the metric value of the batch to update the sum.
  • batch_size (int): the value to update the num, which is the batch size.
def compute(self) -> torch.Tensor:
 96    def compute(self) -> Tensor:
 97        r"""Compute this mean metric value till this batch.
 98
 99        **Returns:**
100        - **mean** (`Tensor`): the calculated mean result.
101        """
102        return self.sum / self.num

Compute this mean metric value till this batch.

Returns:

  • mean (Tensor): the calculated mean result.
class HATNetworkCapacityMetric(torchmetrics.aggregation.BaseAggregator):
105class HATNetworkCapacityMetric(BaseAggregator):
106    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.
107
108    Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
109    """
110
111    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
112        r"""Initialise the HAT network capacity metric. Add state variables."""
113        BaseAggregator.__init__(
114            self,
115            "sum",
116            torch.tensor(0.0, dtype=torch.get_default_dtype()),
117            nan_strategy,
118            state_name="sum_adjustment_rate",
119            **kwargs,
120        )
121        self.sum_adjustment_rate: Tensor
122        r"""State variable created by `add_state()` to store the sum of the adjustment rate values till this layer."""
123
124        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
125        self.num_params: Tensor
126        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""
127
128    def update(
129        self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
130    ) -> None:
131        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.
132
133        **Args:**
134        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
135        - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
136        """
137        adjustment_rate_weight_layer = torch.as_tensor(
138            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
139        )
140        adjustment_rate_bias_layer = torch.as_tensor(
141            adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
142        )
143
144        self.sum_adjustment_rate += (
145            adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum()
146        )
147        self.num_params += (
148            adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel()
149        )
150
151    def compute(self) -> Tensor:
152        r"""Compute this HAT network capacity till this layer.
153
154        **Returns:**
155        - **network_capacity** (`Tensor`): the calculated network capacity result.
156        """
157
158        return self.sum_adjustment_rate / self.num_params

A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the AdaHAT paper.

HATNetworkCapacityMetric(nan_strategy: str | float = 'error', **kwargs: Any)
111    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
112        r"""Initialise the HAT network capacity metric. Add state variables."""
113        BaseAggregator.__init__(
114            self,
115            "sum",
116            torch.tensor(0.0, dtype=torch.get_default_dtype()),
117            nan_strategy,
118            state_name="sum_adjustment_rate",
119            **kwargs,
120        )
121        self.sum_adjustment_rate: Tensor
122        r"""State variable created by `add_state()` to store the sum of the adjustment rate values till this layer."""
123
124        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
125        self.num_params: Tensor
126        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""

Initialise the HAT network capacity metric. Add state variables.

sum_adjustment_rate: torch.Tensor

State variable created by the base-class constructor (via state_name="sum_adjustment_rate") to store the sum of the adjustment rate values till this layer.

num_params: torch.Tensor

State variable created by add_state() to store the number of the parameters till this layer.

def update( self, adjustment_rate_weight_layer: torch.Tensor, adjustment_rate_bias_layer: torch.Tensor) -> None:
128    def update(
129        self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
130    ) -> None:
131        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.
132
133        **Args:**
134        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
135        - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
136        """
137        adjustment_rate_weight_layer = torch.as_tensor(
138            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
139        )
140        adjustment_rate_bias_layer = torch.as_tensor(
141            adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
142        )
143
144        self.sum_adjustment_rate += (
145            adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum()
146        )
147        self.num_params += (
148            adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel()
149        )

Update and accumulate the sum of adjustment rate values till this layer from the layer.

Args:

  • adjustment_rate_weight_layer (Tensor): the adjustment rate values of the weight matrix of the layer.
  • adjustment_rate_bias_layer (Tensor): the adjustment rate values of the bias vector of the layer.
def compute(self) -> torch.Tensor:
151    def compute(self) -> Tensor:
152        r"""Compute this HAT network capacity till this layer.
153
154        **Returns:**
155        - **network_capacity** (`Tensor`): the calculated network capacity result.
156        """
157
158        return self.sum_adjustment_rate / self.num_params

Compute this HAT network capacity till this layer.

Returns:

  • network_capacity (Tensor): the calculated network capacity result.