clarena.utils.metrics
The submodule in `utils` for custom metrics.
"""The submodule in `utils` for custom metrics."""

__all__ = ["MeanMetricBatch", "HATNetworkCapacity"]

from typing import Any

import torch
from torch import Tensor
from torchmetrics.aggregation import BaseAggregator


class MeanMetricBatch(BaseAggregator):
    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.

    Each batch's value is weighted by its batch size, so the result is a mean over data
    points rather than over batches. This is used for accumulated metrics in deep learning.
    See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate)
    for more details.
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialise the metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): forwarded unchanged to `BaseAggregator.__init__()`.
          NOTE(review): `update()` below performs no NaN check of its own, so confirm whether
          this strategy is actually enforced anywhere.
        - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator.__init__()`.
        """
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum",
            **kwargs,
        )

        self.sum: Tensor
        r"""State variable created by `BaseAggregator.__init__()` to store the batch-size-weighted sum of the metric values till this batch."""

        self.add_state(
            "num",
            default=torch.tensor(0, dtype=torch.get_default_dtype()),
            dist_reduce_fx="sum",
        )
        self.num: Tensor
        r"""State variable created by `add_state()` to store the number of the data till this batch."""

    def update(self, value: torch.Tensor, batch_size: int) -> None:
        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.

        **Args:**
        - **value** (`torch.Tensor`): the metric value of the batch to update the sum. Presumably the
          batch's mean value, since it is multiplied by `batch_size` before being accumulated.
        - **batch_size** (`int`): the value to update the num, which is the batch size.
        """
        # Convert both inputs to the metric's dtype/device so the in-place updates match the states.
        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)

        # Weight by batch size so that `sum / num` is a mean over data points, not over batches.
        self.sum += value * batch_size
        self.num += batch_size

    def compute(self) -> Tensor:
        r"""Compute this mean metric value till this batch.

        If no data has been accumulated (both states still at their default 0), this is 0/0 and
        yields NaN.

        **Returns:**
        - **mean** (`Tensor`): the calculated mean result.
        """
        return self.sum / self.num


class HATNetworkCapacity(BaseAggregator):
    r"""A TorchMetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

    Network capacity is defined as the average adjustment rate over all parameters. See Section 4.1
    of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialise the HAT network capacity metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): forwarded unchanged to `BaseAggregator.__init__()`.
          NOTE(review): `update()` below performs no NaN check of its own, so confirm whether
          this strategy is actually enforced anywhere.
        - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator.__init__()`.
        """
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum_adjustment_rate",
            **kwargs,
        )
        self.sum_adjustment_rate: Tensor
        r"""State variable created by `BaseAggregator.__init__()` to store the sum of the adjustment rate values till this layer."""

        # Integer counter: it only ever receives `numel()` counts in `update()`.
        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
        self.num_params: Tensor
        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""

    def update(
        self, adjustment_rate_weight: Tensor, adjustment_rate_bias: Tensor
    ) -> None:
        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.

        **Args:**
        - **adjustment_rate_weight** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
        - **adjustment_rate_bias** (`Tensor`): the adjustment rate values of the bias vector of the layer.
        """
        adjustment_rate_weight = torch.as_tensor(
            adjustment_rate_weight, dtype=self.dtype, device=self.device
        )
        adjustment_rate_bias = torch.as_tensor(
            adjustment_rate_bias, dtype=self.dtype, device=self.device
        )

        # Every element is one parameter, so the count grows by the total number of elements.
        self.sum_adjustment_rate += (
            adjustment_rate_weight.sum() + adjustment_rate_bias.sum()
        )
        self.num_params += adjustment_rate_weight.numel() + adjustment_rate_bias.numel()

    def compute(self) -> Tensor:
        r"""Compute this HAT network capacity till this layer.

        If nothing has been accumulated (both states still at their default 0), this is 0/0 and
        yields NaN.

        **Returns:**
        - **network_capacity** (`Tensor`): the calculated network capacity result.
        """
        return self.sum_adjustment_rate / self.num_params
class MeanMetricBatch(BaseAggregator):
    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.

    Each batch's value is weighted by its batch size, so the result is a mean over data
    points rather than over batches. This is used for accumulated metrics in deep learning.
    See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate)
    for more details.
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialise the metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): forwarded unchanged to `BaseAggregator.__init__()`.
          NOTE(review): `update()` below performs no NaN check of its own, so confirm whether
          this strategy is actually enforced anywhere.
        - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator.__init__()`.
        """
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum",
            **kwargs,
        )

        self.sum: Tensor
        r"""State variable created by `BaseAggregator.__init__()` to store the batch-size-weighted sum of the metric values till this batch."""

        self.add_state(
            "num",
            default=torch.tensor(0, dtype=torch.get_default_dtype()),
            dist_reduce_fx="sum",
        )
        self.num: Tensor
        r"""State variable created by `add_state()` to store the number of the data till this batch."""

    def update(self, value: torch.Tensor, batch_size: int) -> None:
        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.

        **Args:**
        - **value** (`torch.Tensor`): the metric value of the batch to update the sum. Presumably the
          batch's mean value, since it is multiplied by `batch_size` before being accumulated.
        - **batch_size** (`int`): the value to update the num, which is the batch size.
        """
        # Convert both inputs to the metric's dtype/device so the in-place updates match the states.
        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)

        # Weight by batch size so that `sum / num` is a mean over data points, not over batches.
        self.sum += value * batch_size
        self.num += batch_size

    def compute(self) -> Tensor:
        r"""Compute this mean metric value till this batch.

        If no data has been accumulated (both states still at their default 0), this is 0/0 and
        yields NaN.

        **Returns:**
        - **mean** (`Tensor`): the calculated mean result.
        """
        return self.sum / self.num
A TorchMetrics metric to calculate the mean of metrics across data batches.
This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
    r"""Initialise the metric. Add state variables.

    **Args:**
    - **nan_strategy** (`str` | `float`): forwarded unchanged to `BaseAggregator.__init__()`.
      NOTE(review): this class's `update()` performs no NaN check of its own, so confirm whether
      this strategy is actually enforced anywhere.
    - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator.__init__()`.
    """
    BaseAggregator.__init__(
        self,
        "sum",
        torch.tensor(0.0, dtype=torch.get_default_dtype()),
        nan_strategy,
        state_name="sum",
        **kwargs,
    )

    self.sum: Tensor
    r"""State variable created by `BaseAggregator.__init__()` to store the batch-size-weighted sum of the metric values till this batch."""

    self.add_state(
        "num",
        default=torch.tensor(0, dtype=torch.get_default_dtype()),
        dist_reduce_fx="sum",
    )
    self.num: Tensor
    r"""State variable created by `add_state()` to store the number of the data till this batch."""
Initialise the metric. Add state variables.
State variable created by super().__init__()
to store the sum of the metric values till this batch.
State variable created by add_state()
to store the number of the data till this batch.
def update(self, value: torch.Tensor, batch_size: int) -> None:
    r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.

    **Args:**
    - **value** (`torch.Tensor`): the metric value of the batch to update the sum. Presumably the
      batch's mean value, since it is multiplied by `batch_size` before being accumulated.
    - **batch_size** (`int`): the value to update the num, which is the batch size.
    """
    # Convert both inputs to the metric's dtype/device so the in-place updates match the states.
    value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
    batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)

    # Weight by batch size so that `sum / num` is a mean over data points, not over batches.
    self.sum += value * batch_size
    self.num += batch_size
Update and accumulate the sum of metric value and num of the data till this batch from the batch.
Args:
- value (torch.Tensor): the metric value of the batch to update the sum.
- batch_size (int): the value to update the num, which is the batch size.
def compute(self) -> Tensor:
    r"""Compute the mean metric value accumulated so far.

    **Returns:**
    - **mean** (`Tensor`): the accumulated batch-size-weighted sum divided by the accumulated
      number of data.
    """
    total, count = self.sum, self.num
    return total / count
Compute this mean metric value till this batch.
Returns:
- mean (Tensor): the calculated mean result.
class HATNetworkCapacity(BaseAggregator):
    r"""A TorchMetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

    Network capacity is defined as the average adjustment rate over all parameters. See Section 4.1
    of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
    """

    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
        r"""Initialise the HAT network capacity metric. Add state variables.

        **Args:**
        - **nan_strategy** (`str` | `float`): forwarded unchanged to `BaseAggregator.__init__()`.
          NOTE(review): `update()` below performs no NaN check of its own, so confirm whether
          this strategy is actually enforced anywhere.
        - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator.__init__()`.
        """
        BaseAggregator.__init__(
            self,
            "sum",
            torch.tensor(0.0, dtype=torch.get_default_dtype()),
            nan_strategy,
            state_name="sum_adjustment_rate",
            **kwargs,
        )
        self.sum_adjustment_rate: Tensor
        r"""State variable created by `BaseAggregator.__init__()` to store the sum of the adjustment rate values till this layer."""

        # Integer counter: it only ever receives `numel()` counts in `update()`.
        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
        self.num_params: Tensor
        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""

    def update(
        self, adjustment_rate_weight: Tensor, adjustment_rate_bias: Tensor
    ) -> None:
        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.

        **Args:**
        - **adjustment_rate_weight** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
        - **adjustment_rate_bias** (`Tensor`): the adjustment rate values of the bias vector of the layer.
        """
        adjustment_rate_weight = torch.as_tensor(
            adjustment_rate_weight, dtype=self.dtype, device=self.device
        )
        adjustment_rate_bias = torch.as_tensor(
            adjustment_rate_bias, dtype=self.dtype, device=self.device
        )

        # Every element is one parameter, so the count grows by the total number of elements.
        self.sum_adjustment_rate += (
            adjustment_rate_weight.sum() + adjustment_rate_bias.sum()
        )
        self.num_params += adjustment_rate_weight.numel() + adjustment_rate_bias.numel()

    def compute(self) -> Tensor:
        r"""Compute this HAT network capacity till this layer.

        If nothing has been accumulated (both states still at their default 0), this is 0/0 and
        yields NaN.

        **Returns:**
        - **network_capacity** (`Tensor`): the calculated network capacity result.
        """
        return self.sum_adjustment_rate / self.num_params
A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.
Network capacity is defined as the average adjustment rate over all parameters. See Section 4.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
    r"""Initialise the HAT network capacity metric. Add state variables.

    **Args:**
    - **nan_strategy** (`str` | `float`): forwarded unchanged to `BaseAggregator.__init__()`.
      NOTE(review): this class's `update()` performs no NaN check of its own, so confirm whether
      this strategy is actually enforced anywhere.
    - **kwargs** (`Any`): extra keyword arguments forwarded to `BaseAggregator.__init__()`.
    """
    BaseAggregator.__init__(
        self,
        "sum",
        torch.tensor(0.0, dtype=torch.get_default_dtype()),
        nan_strategy,
        state_name="sum_adjustment_rate",
        **kwargs,
    )
    self.sum_adjustment_rate: Tensor
    r"""State variable created by `BaseAggregator.__init__()` to store the sum of the adjustment rate values till this layer."""

    # Integer counter: it only ever receives `numel()` counts in `update()`.
    self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
    self.num_params: Tensor
    r"""State variable created by `add_state()` to store the number of the parameters till this layer."""
Initialise the HAT network capacity metric. Add state variables.
State variable created by add_state()
to store the sum of the adjustment rate values till this layer.
State variable created by add_state()
to store the number of the parameters till this layer.
def update(
    self, adjustment_rate_weight: Tensor, adjustment_rate_bias: Tensor
) -> None:
    r"""Fold one layer's adjustment rates into the running sum and parameter count.

    **Args:**
    - **adjustment_rate_weight** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
    - **adjustment_rate_bias** (`Tensor`): the adjustment rate values of the bias vector of the layer.
    """
    # Convert the weight rates first, then the bias rates, onto the metric's dtype/device.
    weight_rates, bias_rates = (
        torch.as_tensor(rates, dtype=self.dtype, device=self.device)
        for rates in (adjustment_rate_weight, adjustment_rate_bias)
    )

    layer_sum = weight_rates.sum() + bias_rates.sum()
    layer_count = weight_rates.numel() + bias_rates.numel()

    self.sum_adjustment_rate += layer_sum
    self.num_params += layer_count
Update and accumulate the sum of adjustment rate values till this layer from the layer.
Args:
- adjustment_rate_weight (Tensor): the adjustment rate values of the weight matrix of the layer.
- adjustment_rate_bias (Tensor): the adjustment rate values of the bias vector of the layer.
def compute(self) -> Tensor:
    r"""Compute the HAT network capacity over all layers accumulated so far.

    **Returns:**
    - **network_capacity** (`Tensor`): the summed adjustment rates divided by the number of
      parameters they cover, i.e. the average adjustment rate.
    """
    summed_rates = self.sum_adjustment_rate
    param_count = self.num_params
    return summed_rates / param_count
Compute this HAT network capacity till this layer.
Returns:
- network_capacity (Tensor): the calculated network capacity result.