clarena.utils.metrics
The submodule in utils for custom torchmetrics.
r"""
The submodule in `utils` for custom torchmetrics.
"""

__all__ = ["MeanMetricBatch", "HATNetworkCapacityMetric"]


import logging
from typing import Any

import torch
from torch import Tensor
from torchmetrics.aggregation import BaseAggregator

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class MeanMetricBatch(BaseAggregator):
    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.

    Each batch's metric value is weighted by its batch size, so the result is the mean over all data seen so far
    rather than the unweighted mean of the per-batch values.

    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialize the metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): forwarded to `BaseAggregator`. NOTE(review): `update()` does not call
          `BaseAggregator`'s NaN check, so this setting is currently not applied to incoming values.
        - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator`.
        """
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum",
            **kwargs,
        )

        self.sum: Tensor
        r"""State variable created by `super().__init__()` to store the batch-size-weighted sum of the metric values till this batch."""

        self.add_state(
            "num",
            default=torch.tensor(0, dtype=torch.get_default_dtype()),
            dist_reduce_fx="sum",
        )
        self.num: Tensor
        r"""State variable created by `add_state()` to store the number of data till this batch."""

    def update(self, value: Tensor | float, batch_size: int) -> None:
        r"""Accumulate one batch: add `value * batch_size` to the sum and `batch_size` to the number of data.

        **Args:**
        - **value** (`Tensor` | `float`): the metric value of the batch, used to update the sum. Because it is weighted
          by `batch_size`, it is expected to be the batch's per-sample mean.
        - **batch_size** (`int`): the number of data in the batch, used to update the num.
        """
        # cast to the metric's dtype/device so the in-place additions to the state tensors are valid
        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)

        # weight by batch size so compute() yields the per-sample mean, not the mean of batch values
        self.sum += value * batch_size
        self.num += batch_size

    def compute(self) -> Tensor:
        r"""Compute this mean metric value till this batch.

        Both states start at zero, so calling this before any `update()` divides 0 by 0.

        **Returns:**
        - **mean** (`Tensor`): the calculated mean result, i.e. the weighted sum divided by the number of data.
        """
        return self.sum / self.num


class HATNetworkCapacityMetric(BaseAggregator):
    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

    Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialize the HAT network capacity metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): forwarded to `BaseAggregator`. NOTE(review): `update()` does not call
          `BaseAggregator`'s NaN check, so this setting is currently not applied to incoming values.
        - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator`.
        """
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum_adjustment_rate",
            **kwargs,
        )
        self.sum_adjustment_rate: Tensor
        r"""State variable created by `BaseAggregator.__init__()` (via `state_name`) to store the sum of the adjustment rate values till this layer."""

        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
        self.num_params: Tensor
        r"""State variable created by `add_state()` to store the number of parameters till this layer."""

    def update(
        self,
        adjustment_rate_weight_layer: Tensor,
        adjustment_rate_bias_layer: Tensor | None,
    ) -> None:
        r"""Accumulate one layer's adjustment rate values into the sum and its parameter count into the num.

        **Args:**
        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
        - **adjustment_rate_bias_layer** (`Tensor` | `None`): the adjustment rate values of the bias vector of the layer.
          Pass `None` for a layer without a bias; it then contributes nothing to the sum or the num.
        """
        weight_rates = torch.as_tensor(
            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
        )
        layer_sum = weight_rates.sum()
        layer_numel = weight_rates.numel()

        # bias-less layers have no bias adjustment rates; torch.as_tensor(None) would raise
        if adjustment_rate_bias_layer is not None:
            bias_rates = torch.as_tensor(
                adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
            )
            layer_sum = layer_sum + bias_rates.sum()
            layer_numel += bias_rates.numel()

        # add the whole layer at once, matching the original order of floating-point additions
        self.sum_adjustment_rate += layer_sum
        self.num_params += layer_numel

    def compute(self) -> Tensor:
        r"""Compute this HAT network capacity till this layer.

        Both states start at zero, so calling this before any `update()` divides 0 by 0.

        **Returns:**
        - **network_capacity** (`Tensor`): the calculated network capacity result, i.e. the summed adjustment rate
          divided by the number of parameters.
        """
        return self.sum_adjustment_rate / self.num_params
class MeanMetricBatch(BaseAggregator):
    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.

    Each batch's metric value is weighted by its batch size, so the result is the mean over all data seen so far
    rather than the unweighted mean of the per-batch values.

    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialize the metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): forwarded to `BaseAggregator`. NOTE(review): `update()` does not call
          `BaseAggregator`'s NaN check, so this setting is currently not applied to incoming values.
        - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator`.
        """
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum",
            **kwargs,
        )

        self.sum: Tensor
        r"""State variable created by `super().__init__()` to store the batch-size-weighted sum of the metric values till this batch."""

        self.add_state(
            "num",
            default=torch.tensor(0, dtype=torch.get_default_dtype()),
            dist_reduce_fx="sum",
        )
        self.num: Tensor
        r"""State variable created by `add_state()` to store the number of data till this batch."""

    def update(self, value: Tensor | float, batch_size: int) -> None:
        r"""Accumulate one batch: add `value * batch_size` to the sum and `batch_size` to the number of data.

        **Args:**
        - **value** (`Tensor` | `float`): the metric value of the batch, used to update the sum. Because it is weighted
          by `batch_size`, it is expected to be the batch's per-sample mean.
        - **batch_size** (`int`): the number of data in the batch, used to update the num.
        """
        # cast to the metric's dtype/device so the in-place additions to the state tensors are valid
        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)

        # weight by batch size so compute() yields the per-sample mean, not the mean of batch values
        self.sum += value * batch_size
        self.num += batch_size

    def compute(self) -> Tensor:
        r"""Compute this mean metric value till this batch.

        Both states start at zero, so calling this before any `update()` divides 0 by 0.

        **Returns:**
        - **mean** (`Tensor`): the calculated mean result, i.e. the weighted sum divided by the number of data.
        """
        return self.sum / self.num
A TorchMetrics metric to calculate the mean of metrics across data batches.
This is used for accumulated metrics in deep learning. See here for more details.
def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
    r"""Set up the metric and register its two state variables.

    **Args:**
    - **nan_strategy** (`str` | `float`): passed through to `BaseAggregator`. NOTE(review): `update()` does not run
      `BaseAggregator`'s NaN check, so this setting is currently not applied to incoming values.
    - **kwargs** (`Any`): extra keyword arguments passed through to `BaseAggregator`.
    """
    float_dtype = torch.get_default_dtype()

    # BaseAggregator registers the running-sum state under the name "sum"
    BaseAggregator.__init__(
        self, "sum", torch.tensor(0.0, dtype=float_dtype), nan_strategy, state_name="sum", **kwargs
    )
    self.sum: Tensor
    r"""State variable created by `super().__init__()` to store the batch-size-weighted sum of the metric values till this batch."""

    # running count of data; dist_reduce_fx="sum" adds the counts together when states are synced
    self.add_state("num", default=torch.tensor(0, dtype=float_dtype), dist_reduce_fx="sum")
    self.num: Tensor
    r"""State variable created by `add_state()` to store the number of data till this batch."""
Initialize the metric. Add state variables.
State variable created by super().__init__() to store the batch-size-weighted sum of the metric values accumulated up to this batch.
State variable created by add_state() to store the number of data samples accumulated up to this batch.
48 def update(self, value: torch.Tensor, batch_size: int) -> None: 49 r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch. 50 51 **Args:** 52 - **val** (`torch.Tensor`): the metric value of the batch to update the sum. 53 - **batch_size** (`int`): the value to update the num, which is the batch size. 54 """ 55 56 value = torch.as_tensor(value, dtype=self.dtype, device=self.device) 57 batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device) 58 59 self.sum += value * batch_size 60 self.num += batch_size
Update the accumulated sum of metric values and the number of data samples with the current batch.
Args:
- value (torch.Tensor): the metric value of the batch, used to update the sum.
- batch_size (int): the batch size, used to update the number of data samples.
def compute(self) -> Tensor:
    r"""Return the accumulated mean: the weighted sum divided by the number of data seen so far.

    **Returns:**
    - **mean** (`Tensor`): the calculated mean result.
    """
    total = self.sum
    count = self.num
    return total / count
Compute this mean metric value till this batch.
Returns:
- mean (Tensor): the calculated mean result.
class HATNetworkCapacityMetric(BaseAggregator):
    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

    Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialize the HAT network capacity metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): forwarded to `BaseAggregator`. NOTE(review): `update()` does not call
          `BaseAggregator`'s NaN check, so this setting is currently not applied to incoming values.
        - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator`.
        """
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum_adjustment_rate",
            **kwargs,
        )
        self.sum_adjustment_rate: Tensor
        r"""State variable created by `BaseAggregator.__init__()` (via `state_name`) to store the sum of the adjustment rate values till this layer."""

        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
        self.num_params: Tensor
        r"""State variable created by `add_state()` to store the number of parameters till this layer."""

    def update(
        self,
        adjustment_rate_weight_layer: Tensor,
        adjustment_rate_bias_layer: Tensor | None,
    ) -> None:
        r"""Accumulate one layer's adjustment rate values into the sum and its parameter count into the num.

        **Args:**
        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
        - **adjustment_rate_bias_layer** (`Tensor` | `None`): the adjustment rate values of the bias vector of the layer.
          Pass `None` for a layer without a bias; it then contributes nothing to the sum or the num.
        """
        weight_rates = torch.as_tensor(
            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
        )
        layer_sum = weight_rates.sum()
        layer_numel = weight_rates.numel()

        # bias-less layers have no bias adjustment rates; torch.as_tensor(None) would raise
        if adjustment_rate_bias_layer is not None:
            bias_rates = torch.as_tensor(
                adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
            )
            layer_sum = layer_sum + bias_rates.sum()
            layer_numel += bias_rates.numel()

        # add the whole layer at once, matching the original order of floating-point additions
        self.sum_adjustment_rate += layer_sum
        self.num_params += layer_numel

    def compute(self) -> Tensor:
        r"""Compute this HAT network capacity till this layer.

        Both states start at zero, so calling this before any `update()` divides 0 by 0.

        **Returns:**
        - **network_capacity** (`Tensor`): the calculated network capacity result, i.e. the summed adjustment rate
          divided by the number of parameters.
        """
        return self.sum_adjustment_rate / self.num_params
A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.
Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the AdaHAT paper.
def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
    r"""Initialize the HAT network capacity metric. Add state variables.

    **Args:**
    - **nan_strategy** (`str` | `float`): forwarded to `BaseAggregator`. NOTE(review): this class's `update()` does
      not call `BaseAggregator`'s NaN check, so this setting is currently not applied to incoming values.
    - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator`.
    """
    BaseAggregator.__init__(
        self,
        "sum",
        torch.tensor(0.0, dtype=torch.get_default_dtype()),
        nan_strategy,
        state_name="sum_adjustment_rate",
        **kwargs,
    )
    self.sum_adjustment_rate: Tensor
    r"""State variable created by `BaseAggregator.__init__()` (via `state_name`) to store the sum of the adjustment rate values till this layer."""

    self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
    self.num_params: Tensor
    r"""State variable created by `add_state()` to store the number of parameters till this layer."""
Initialise the HAT network capacity metric. Add state variables.
State variable created by BaseAggregator.__init__() (via state_name) to store the sum of the adjustment rate values accumulated up to this layer.
State variable created by add_state() to store the number of parameters accumulated up to this layer.
94 def update( 95 self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor 96 ) -> None: 97 r"""Update and accumulate the sum of adjustment rate values till this layer from the layer. 98 99 **Args:** 100 - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer. 101 - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer. 102 """ 103 adjustment_rate_weight_layer = torch.as_tensor( 104 adjustment_rate_weight_layer, dtype=self.dtype, device=self.device 105 ) 106 adjustment_rate_bias_layer = torch.as_tensor( 107 adjustment_rate_bias_layer, dtype=self.dtype, device=self.device 108 ) 109 110 self.sum_adjustment_rate += ( 111 adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum() 112 ) 113 self.num_params += ( 114 adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel() 115 )
Update the accumulated sum of adjustment rate values and the parameter count with the current layer.
Args:
- adjustment_rate_weight_layer (Tensor): the adjustment rate values of the weight matrix of the layer.
- adjustment_rate_bias_layer (Tensor): the adjustment rate values of the bias vector of the layer.
def compute(self) -> Tensor:
    r"""Return the network capacity accumulated so far: the mean adjustment rate over all parameters seen.

    **Returns:**
    - **network_capacity** (`Tensor`): the calculated network capacity result.
    """
    total_rate = self.sum_adjustment_rate
    param_count = self.num_params
    return total_rate / param_count
Compute this HAT network capacity till this layer.
Returns:
- network_capacity (Tensor): the calculated network capacity result.