clarena.utils.metrics
The submodule in utils for metric utilities and custom torchmetrics.
r"""
The submodule in `utils` for metric utilities and custom torchmetrics.
"""

__all__ = ["linear_cka", "MeanMetricBatch", "HATNetworkCapacityMetric"]


import logging
from typing import Any

import torch
from torch import Tensor
from torchmetrics.aggregation import BaseAggregator

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


def linear_cka(x: Tensor, y: Tensor, eps: float = 1e-12) -> Tensor:
    r"""Compute linear CKA similarity between two representation matrices.

    **Args:**
    - **x** (`Tensor`): the first representation matrix with shape `(num_samples, num_features_x)`.
    - **y** (`Tensor`): the second representation matrix with shape `(num_samples, num_features_y)`.
    - **eps** (`float`): a small constant to avoid division by zero.

    **Returns:**
    - **cka_similarity** (`Tensor`): a scalar tensor of linear CKA similarity.
    """
    if x.dim() != 2:
        raise ValueError(
            f"Expected `x` to be a 2D tensor with shape (num_samples, num_features_x), but got {x.dim()} dimensions."
        )
    if y.dim() != 2:
        raise ValueError(
            f"Expected `y` to be a 2D tensor with shape (num_samples, num_features_y), but got {y.dim()} dimensions."
        )
    if x.size(0) != y.size(0):
        raise ValueError(
            f"Expected `x` and `y` to have the same number of samples, but got {x.size(0)} and {y.size(0)}."
        )

    # Center features across samples before computing similarities.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)

    cross_cov_norm = torch.norm(x.T @ y, p="fro") ** 2
    normalization = torch.norm(x.T @ x, p="fro") * torch.norm(y.T @ y, p="fro")

    return cross_cov_norm / (normalization + eps)


class MeanMetricBatch(BaseAggregator):
    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.

    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialize the metric. Add state variables."""
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum",
            **kwargs,
        )

        self.sum: Tensor
        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""

        self.add_state(
            "num",
            default=torch.tensor(0, dtype=torch.get_default_dtype()),
            dist_reduce_fx="sum",
        )
        self.num: Tensor
        r"""State variable created by `add_state()` to store the number of the data till this batch."""

    def update(self, value: torch.Tensor, batch_size: int) -> None:
        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.

        **Args:**
        - **value** (`torch.Tensor`): the metric value of the batch to update the sum.
        - **batch_size** (`int`): the value to update the num, which is the batch size.
        """
        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)

        # Weight by batch size so that `compute` yields a per-sample mean.
        self.sum += value * batch_size
        self.num += batch_size

    def compute(self) -> Tensor:
        r"""Compute this mean metric value till this batch.

        **Returns:**
        - **mean** (`Tensor`): the calculated mean result.
        """
        return self.sum / self.num


class HATNetworkCapacityMetric(BaseAggregator):
    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

    Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialise the HAT network capacity metric. Add state variables."""
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum_adjustment_rate",
            **kwargs,
        )
        self.sum_adjustment_rate: Tensor
        r"""State variable created by `super().__init__()` to store the sum of the adjustment rate values till this layer."""

        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
        self.num_params: Tensor
        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""

    def update(
        self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
    ) -> None:
        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.

        **Args:**
        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
        - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
        """
        adjustment_rate_weight_layer = torch.as_tensor(
            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
        )
        adjustment_rate_bias_layer = torch.as_tensor(
            adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
        )

        self.sum_adjustment_rate += (
            adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum()
        )
        self.num_params += (
            adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel()
        )

    def compute(self) -> Tensor:
        r"""Compute this HAT network capacity till this layer.

        **Returns:**
        - **network_capacity** (`Tensor`): the calculated network capacity result.
        """
        return self.sum_adjustment_rate / self.num_params
def linear_cka(x: Tensor, y: Tensor, eps: float = 1e-12) -> Tensor:
    r"""Compute linear CKA similarity between two representation matrices.

    **Args:**
    - **x** (`Tensor`): the first representation matrix with shape `(num_samples, num_features_x)`.
    - **y** (`Tensor`): the second representation matrix with shape `(num_samples, num_features_y)`.
    - **eps** (`float`): a small constant to avoid division by zero.

    **Returns:**
    - **cka_similarity** (`Tensor`): a scalar tensor of linear CKA similarity.
    """
    # Validate the inputs: both must be 2D with matching sample counts.
    if x.dim() != 2:
        raise ValueError(
            f"Expected `x` to be a 2D tensor with shape (num_samples, num_features_x), but got {x.dim()} dimensions."
        )
    if y.dim() != 2:
        raise ValueError(
            f"Expected `y` to be a 2D tensor with shape (num_samples, num_features_y), but got {y.dim()} dimensions."
        )
    if x.size(0) != y.size(0):
        raise ValueError(
            f"Expected `x` and `y` to have the same number of samples, but got {x.size(0)} and {y.size(0)}."
        )

    # Subtract each feature's mean over the samples; CKA is defined on
    # centered representations.
    x_centered = x - x.mean(dim=0, keepdim=True)
    y_centered = y - y.mean(dim=0, keepdim=True)

    # Squared Frobenius norm of the cross-covariance, normalized by the
    # norms of the two self-covariances (linear-kernel HSIC formulation).
    numerator = torch.norm(x_centered.T @ y_centered, p="fro") ** 2
    denominator = torch.norm(x_centered.T @ x_centered, p="fro") * torch.norm(
        y_centered.T @ y_centered, p="fro"
    )

    return numerator / (denominator + eps)
Compute linear CKA similarity between two representation matrices.
Args:
- x (Tensor): the first representation matrix with shape (num_samples, num_features_x).
- y (Tensor): the second representation matrix with shape (num_samples, num_features_y).
- eps (float): a small constant to avoid division by zero.
Returns:
- cka_similarity (Tensor): a scalar tensor of linear CKA similarity.
class MeanMetricBatch(BaseAggregator):
    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.

    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialize the metric. Add state variables."""
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum",
            **kwargs,
        )

        self.sum: Tensor
        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""

        self.add_state(
            "num",
            default=torch.tensor(0, dtype=torch.get_default_dtype()),
            dist_reduce_fx="sum",
        )
        self.num: Tensor
        r"""State variable created by `add_state()` to store the number of the data till this batch."""

    def update(self, value: torch.Tensor, batch_size: int) -> None:
        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.

        **Args:**
        - **value** (`torch.Tensor`): the metric value of the batch to update the sum.
        - **batch_size** (`int`): the value to update the num, which is the batch size.
        """
        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)

        # Weight by batch size so that `compute` yields a per-sample mean.
        self.sum += value * batch_size
        self.num += batch_size

    def compute(self) -> Tensor:
        r"""Compute this mean metric value till this batch.

        **Returns:**
        - **mean** (`Tensor`): the calculated mean result.
        """
        return self.sum / self.num
A TorchMetrics metric to calculate the mean of metrics across data batches.
This is used for accumulated metrics in deep learning. See here for more details.
def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
    r"""Initialize the metric and register its state variables.

    **Args:**
    - **nan_strategy** (`str` | `float`): forwarded to `BaseAggregator` to control NaN handling.
    - **kwargs** (`Any`): additional keyword arguments forwarded to `BaseAggregator`.
    """
    # Register the running-sum state through the aggregator base class.
    initial_sum = torch.tensor(0.0, dtype=torch.get_default_dtype())
    BaseAggregator.__init__(
        self,
        "sum",
        initial_sum,
        nan_strategy,
        state_name="sum",
        **kwargs,
    )

    # `self.sum` accumulates the (batch-size-weighted) metric values.
    self.sum: Tensor

    # A second state counts the samples seen so far; it is kept in the
    # default float dtype so it reduces alongside `sum`.
    self.add_state(
        "num",
        default=torch.tensor(0, dtype=torch.get_default_dtype()),
        dist_reduce_fx="sum",
    )
    self.num: Tensor
Initialize the metric. Add state variables.
State variable created by super().__init__() to store the sum of the metric values till this batch.
State variable created by add_state() to store the number of the data till this batch.
def update(self, value: torch.Tensor, batch_size: int) -> None:
    r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.

    **Args:**
    - **value** (`torch.Tensor`): the metric value of the batch to update the sum.
    - **batch_size** (`int`): the value to update the num, which is the batch size.
    """
    # Cast both inputs onto the metric's dtype/device before accumulating.
    value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
    batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)

    # Weight by batch size so that `compute` yields a per-sample mean.
    self.sum += value * batch_size
    self.num += batch_size
Update and accumulate the sum of metric value and num of the data till this batch from the batch.
Args:
- value (torch.Tensor): the metric value of the batch to update the sum.
- batch_size (int): the value to update the num, which is the batch size.
def compute(self) -> Tensor:
    r"""Return the mean metric value accumulated up to this batch.

    **Returns:**
    - **mean** (`Tensor`): the accumulated sum divided by the number of samples seen.
    """
    total, count = self.sum, self.num
    return total / count
Compute this mean metric value till this batch.
Returns:
- mean (Tensor): the calculated mean result.
class HATNetworkCapacityMetric(BaseAggregator):
    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

    Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialise the HAT network capacity metric. Add state variables."""
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum_adjustment_rate",
            **kwargs,
        )
        self.sum_adjustment_rate: Tensor
        r"""State variable created by `super().__init__()` to store the sum of the adjustment rate values till this layer."""

        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
        self.num_params: Tensor
        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""

    def update(
        self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
    ) -> None:
        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.

        **Args:**
        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
        - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
        """
        adjustment_rate_weight_layer = torch.as_tensor(
            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
        )
        adjustment_rate_bias_layer = torch.as_tensor(
            adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
        )

        self.sum_adjustment_rate += (
            adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum()
        )
        self.num_params += (
            adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel()
        )

    def compute(self) -> Tensor:
        r"""Compute this HAT network capacity till this layer.

        **Returns:**
        - **network_capacity** (`Tensor`): the calculated network capacity result.
        """
        return self.sum_adjustment_rate / self.num_params
A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.
Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the AdaHAT paper.
def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
    r"""Initialise the HAT network capacity metric and register its state variables.

    **Args:**
    - **nan_strategy** (`str` | `float`): forwarded to `BaseAggregator` to control NaN handling.
    - **kwargs** (`Any`): additional keyword arguments forwarded to `BaseAggregator`.
    """
    # Register the running sum of adjustment rates through the base class.
    initial_sum = torch.tensor(0.0, dtype=torch.get_default_dtype())
    BaseAggregator.__init__(
        self,
        "sum",
        initial_sum,
        nan_strategy,
        state_name="sum_adjustment_rate",
        **kwargs,
    )
    # `self.sum_adjustment_rate` accumulates adjustment rates over layers.
    self.sum_adjustment_rate: Tensor

    # A second (integer) state counts the parameters seen so far.
    self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
    self.num_params: Tensor
Initialise the HAT network capacity metric. Add state variables.
State variable created by add_state() to store the sum of the adjustment rate values till this layer.
State variable created by add_state() to store the number of the parameters till this layer.
def update(
    self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
) -> None:
    r"""Accumulate a layer's adjustment rates into the running totals.

    **Args:**
    - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
    - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
    """
    # Move both tensors onto the metric's dtype/device before accumulating.
    weight_rates = torch.as_tensor(
        adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
    )
    bias_rates = torch.as_tensor(
        adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
    )

    # Track the total adjustment rate and the parameter count so that
    # `compute` can return the average rate over all parameters.
    self.sum_adjustment_rate += weight_rates.sum() + bias_rates.sum()
    self.num_params += weight_rates.numel() + bias_rates.numel()
Update and accumulate the sum of adjustment rate values till this layer from the layer.
Args:
- adjustment_rate_weight_layer (Tensor): the adjustment rate values of the weight matrix of the layer.
- adjustment_rate_bias_layer (Tensor): the adjustment rate values of the bias vector of the layer.
def compute(self) -> Tensor:
    r"""Return the network capacity accumulated over the layers seen so far.

    **Returns:**
    - **network_capacity** (`Tensor`): the average adjustment rate over all parameters.
    """
    total_rate, param_count = self.sum_adjustment_rate, self.num_params
    return total_rate / param_count
Compute this HAT network capacity till this layer.
Returns:
- network_capacity (Tensor): the calculated network capacity result.