clarena.utils.metrics

The submodule in `utils` for metrics.

  1"""The submodule in `utils` for plotting utils."""
  2
  3__all__ = ["MeanMetricBatch", "HATNetworkCapacity"]
  4
  5from typing import Any
  6
  7import torch
  8from torch import Tensor
  9from torchmetrics.aggregation import BaseAggregator
 10
 11
 12class MeanMetricBatch(BaseAggregator):
 13    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.
 14
 15    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
 16    """
 17
 18    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
 19        r"""Initialise the metric. Add state variables."""
 20        BaseAggregator.__init__(
 21            self,
 22            "sum",
 23            torch.tensor(0.0, dtype=torch.get_default_dtype()),
 24            nan_strategy,
 25            state_name="sum",
 26            **kwargs,
 27        )
 28
 29        self.sum: Tensor
 30        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""
 31
 32        self.add_state(
 33            "num",
 34            default=torch.tensor(0, dtype=torch.get_default_dtype()),
 35            dist_reduce_fx="sum",
 36        )
 37        self.num: Tensor
 38        r"""State variable created by `add_state()` to store the number of the data till this batch."""
 39
 40    def update(self, value: torch.Tensor, batch_size: int) -> None:
 41        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.
 42
 43        **Args:**
 44        - **val** (`torch.Tensor`): the metric value of the batch to update the sum.
 45        - **batch_size** (`int`): the value to update the num, which is the batch size.
 46        """
 47
 48        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
 49        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)
 50
 51        self.sum += value * batch_size
 52        self.num += batch_size
 53
 54    def compute(self) -> Tensor:
 55        r"""Compute this mean metric value till this batch.
 56
 57        **Returns:**
 58        - **mean** (`Tensor`): the calculated mean result.
 59        """
 60        return self.sum / self.num
 61
 62
 63class HATNetworkCapacity(BaseAggregator):
 64    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.
 65
 66    Network capacity is defined as the average adjustment rate over all paramaters. See chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 67    """
 68
 69    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
 70        r"""Initialise the HAT network capacity metric. Add state variables."""
 71        BaseAggregator.__init__(
 72            self,
 73            "sum",
 74            torch.tensor(0.0, dtype=torch.get_default_dtype()),
 75            nan_strategy,
 76            state_name="sum_adjustment_rate",
 77            **kwargs,
 78        )
 79        self.sum_adjustment_rate: Tensor
 80        r"""State variable created by `add_state()` to store the sum of the adjustment rate values till this layer."""
 81
 82        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
 83        self.num_params: Tensor
 84        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""
 85
 86    def update(
 87        self, adjustment_rate_weight: Tensor, adjustment_rate_bias: Tensor
 88    ) -> None:
 89        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.
 90
 91        **Args:**
 92        - **adjustment_rate_weight** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
 93        - **adjustment_rate_bias** (`Tensor`): the adjustment rate values of the bias vector of the layer.
 94        """
 95        adjustment_rate_weight = torch.as_tensor(
 96            adjustment_rate_weight, dtype=self.dtype, device=self.device
 97        )
 98        adjustment_rate_bias = torch.as_tensor(
 99            adjustment_rate_bias, dtype=self.dtype, device=self.device
100        )
101
102        self.sum_adjustment_rate += (
103            adjustment_rate_weight.sum() + adjustment_rate_bias.sum()
104        )
105        self.num_params += adjustment_rate_weight.numel() + adjustment_rate_bias.numel()
106
107    def compute(self) -> Tensor:
108        r"""Compute this HAT network capacity till this layer.
109
110        **Returns:**
111        - **network_capacity** (`Tensor`): the calculated network capacity result.
112        """
113
114        return self.sum_adjustment_rate / self.num_params
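
As a quick, hypothetical usage sketch (not part of the module source): `MeanMetricBatch` keeps the running weighted mean `sum(value_i * batch_size_i) / sum(batch_size_i)`, so per-batch metric values can be folded in one batch at a time. The loss values and batch sizes below are made up for illustration.

```python
import torch

from clarena.utils.metrics import MeanMetricBatch

metric = MeanMetricBatch()

# hypothetical per-batch mean losses and their batch sizes
batches = [(torch.tensor(0.9), 32), (torch.tensor(0.7), 32), (torch.tensor(0.5), 16)]
for loss, batch_size in batches:
    metric.update(loss, batch_size)  # accumulates loss * batch_size and batch_size

# weighted mean over all data seen so far: (0.9 * 32 + 0.7 * 32 + 0.5 * 16) / 80
print(metric.compute())  # tensor(0.7400)
```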
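
Similarly, a hypothetical sketch for `HATNetworkCapacity`, which accumulates `sum(adjustment rates) / number of parameters` layer by layer; the adjustment-rate tensors below are fabricated for illustration.

```python
import torch

from clarena.utils.metrics import HATNetworkCapacity

capacity = HATNetworkCapacity()

# hypothetical adjustment rates for two layers: (weight matrix, bias vector)
layer_rates = [
    (torch.full((16, 8), 0.5), torch.full((16,), 0.5)),
    (torch.full((4, 16), 0.25), torch.full((4,), 0.25)),
]
for adjustment_rate_weight, adjustment_rate_bias in layer_rates:
    capacity.update(adjustment_rate_weight, adjustment_rate_bias)

# average adjustment rate over all 212 parameters: (72.0 + 17.0) / 212
print(capacity.compute())  # tensor(0.4198)
```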