clarena.utils.metrics

The submodule in utils for custom torchmetrics.

  1r"""
  2The submodule in `utils` for custom torchmetrics.
  3"""
  4
  5__all__ = ["MeanMetricBatch", "HATNetworkCapacityMetric"]
  6
  7
  8import logging
  9from typing import Any
 10
 11import torch
 12from torch import Tensor
 13from torchmetrics.aggregation import BaseAggregator
 14
 15# always get logger for built-in logging in each module
 16pylogger = logging.getLogger(__name__)
 17
 18
 19class MeanMetricBatch(BaseAggregator):
 20    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.
 21
 22    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
 23    """
 24
 25    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
 26        r"""Initialize the metric. Add state variables."""
 27        BaseAggregator.__init__(
 28            self,
 29            "sum",
 30            torch.tensor(0.0, dtype=torch.get_default_dtype()),
 31            nan_strategy,
 32            state_name="sum",
 33            **kwargs,
 34        )
 35
 36        self.sum: Tensor
 37        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""
 38
 39        self.add_state(
 40            "num",
 41            default=torch.tensor(0, dtype=torch.get_default_dtype()),
 42            dist_reduce_fx="sum",
 43        )
 44        self.num: Tensor
 45        r"""State variable created by `add_state()` to store the number of the data till this batch."""
 46
 47    def update(self, value: torch.Tensor, batch_size: int) -> None:
 48        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.
 49
 50        **Args:**
 51        - **value** (`torch.Tensor`): the metric value of the batch to update the sum.
 52        - **batch_size** (`int`): the value to update the num, which is the batch size.
 53        """
 54
 55        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
 56        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)
 57
 58        self.sum += value * batch_size
 59        self.num += batch_size
 60
 61    def compute(self) -> Tensor:
 62        r"""Compute this mean metric value till this batch.
 63
 64        **Returns:**
 65        - **mean** (`Tensor`): the calculated mean result.
 66        """
 67        return self.sum / self.num
 68
 69
 70class HATNetworkCapacityMetric(BaseAggregator):
 71    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.
 72
 73    Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 74    """
 75
 76    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
 77        r"""Initialise the HAT network capacity metric. Add state variables."""
 78        BaseAggregator.__init__(
 79            self,
 80            "sum",
 81            torch.tensor(0.0, dtype=torch.get_default_dtype()),
 82            nan_strategy,
 83            state_name="sum_adjustment_rate",
 84            **kwargs,
 85        )
 86        self.sum_adjustment_rate: Tensor
 87        r"""State variable created by `add_state()` to store the sum of the adjustment rate values till this layer."""
 88
 89        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
 90        self.num_params: Tensor
 91        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""
 92
 93    def update(
 94        self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
 95    ) -> None:
 96        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.
 97
 98        **Args:**
 99        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
100        - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
101        """
102        adjustment_rate_weight_layer = torch.as_tensor(
103            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
104        )
105        adjustment_rate_bias_layer = torch.as_tensor(
106            adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
107        )
108
109        self.sum_adjustment_rate += (
110            adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum()
111        )
112        self.num_params += (
113            adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel()
114        )
115
116    def compute(self) -> Tensor:
117        r"""Compute this HAT network capacity till this layer.
118
119        **Returns:**
120        - **network_capacity** (`Tensor`): the calculated network capacity result.
121        """
122
123        return self.sum_adjustment_rate / self.num_params
class MeanMetricBatch(torchmetrics.aggregation.BaseAggregator):
20class MeanMetricBatch(BaseAggregator):
21    r"""A TorchMetrics metric to calculate the mean of metrics across data batches.
22
23    This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.
24    """
25
26    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
27        r"""Initialize the metric. Add state variables."""
28        BaseAggregator.__init__(
29            self,
30            "sum",
31            torch.tensor(0.0, dtype=torch.get_default_dtype()),
32            nan_strategy,
33            state_name="sum",
34            **kwargs,
35        )
36
37        self.sum: Tensor
38        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""
39
40        self.add_state(
41            "num",
42            default=torch.tensor(0, dtype=torch.get_default_dtype()),
43            dist_reduce_fx="sum",
44        )
45        self.num: Tensor
46        r"""State variable created by `add_state()` to store the number of the data till this batch."""
47
48    def update(self, value: torch.Tensor, batch_size: int) -> None:
49        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.
50
51        **Args:**
52        - **value** (`torch.Tensor`): the metric value of the batch to update the sum.
53        - **batch_size** (`int`): the value to update the num, which is the batch size.
54        """
55
56        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
57        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)
58
59        self.sum += value * batch_size
60        self.num += batch_size
61
62    def compute(self) -> Tensor:
63        r"""Compute this mean metric value till this batch.
64
65        **Returns:**
66        - **mean** (`Tensor`): the calculated mean result.
67        """
68        return self.sum / self.num

A TorchMetrics metric to calculate the mean of metrics across data batches.

This is used for accumulated metrics in deep learning. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#nte-accumulate) for more details.

MeanMetricBatch(nan_strategy: str | float = 'error', **kwargs: Any)
26    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
27        r"""Initialize the metric. Add state variables."""
28        BaseAggregator.__init__(
29            self,
30            "sum",
31            torch.tensor(0.0, dtype=torch.get_default_dtype()),
32            nan_strategy,
33            state_name="sum",
34            **kwargs,
35        )
36
37        self.sum: Tensor
38        r"""State variable created by `super().__init__()` to store the sum of the metric values till this batch."""
39
40        self.add_state(
41            "num",
42            default=torch.tensor(0, dtype=torch.get_default_dtype()),
43            dist_reduce_fx="sum",
44        )
45        self.num: Tensor
46        r"""State variable created by `add_state()` to store the number of the data till this batch."""

Initialize the metric. Add state variables.

sum: torch.Tensor

State variable created by super().__init__() to store the sum of the metric values till this batch.

num: torch.Tensor

State variable created by add_state() to store the number of the data till this batch.

def update(self, value: torch.Tensor, batch_size: int) -> None:
48    def update(self, value: torch.Tensor, batch_size: int) -> None:
49        r"""Update and accumulate the sum of metric value and num of the data till this batch from the batch.
50
51        **Args:**
52        - **value** (`torch.Tensor`): the metric value of the batch to update the sum.
53        - **batch_size** (`int`): the value to update the num, which is the batch size.
54        """
55
56        value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
57        batch_size = torch.as_tensor(batch_size, dtype=self.dtype, device=self.device)
58
59        self.sum += value * batch_size
60        self.num += batch_size

Update and accumulate the sum of metric value and num of the data till this batch from the batch.

Args:

  • value (torch.Tensor): the metric value of the batch to update the sum.
  • batch_size (int): the value to update the num, which is the batch size.
def compute(self) -> torch.Tensor:
62    def compute(self) -> Tensor:
63        r"""Compute this mean metric value till this batch.
64
65        **Returns:**
66        - **mean** (`Tensor`): the calculated mean result.
67        """
68        return self.sum / self.num

Compute this mean metric value till this batch.

Returns:

  • mean (Tensor): the calculated mean result.
class HATNetworkCapacityMetric(torchmetrics.aggregation.BaseAggregator):
 71class HATNetworkCapacityMetric(BaseAggregator):
 72    r"""A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.
 73
 74    Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).
 75    """
 76
 77    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
 78        r"""Initialise the HAT network capacity metric. Add state variables."""
 79        BaseAggregator.__init__(
 80            self,
 81            "sum",
 82            torch.tensor(0.0, dtype=torch.get_default_dtype()),
 83            nan_strategy,
 84            state_name="sum_adjustment_rate",
 85            **kwargs,
 86        )
 87        self.sum_adjustment_rate: Tensor
 88        r"""State variable created by `add_state()` to store the sum of the adjustment rate values till this layer."""
 89
 90        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
 91        self.num_params: Tensor
 92        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""
 93
 94    def update(
 95        self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
 96    ) -> None:
 97        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.
 98
 99        **Args:**
100        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
101        - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
102        """
103        adjustment_rate_weight_layer = torch.as_tensor(
104            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
105        )
106        adjustment_rate_bias_layer = torch.as_tensor(
107            adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
108        )
109
110        self.sum_adjustment_rate += (
111            adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum()
112        )
113        self.num_params += (
114            adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel()
115        )
116
117    def compute(self) -> Tensor:
118        r"""Compute this HAT network capacity till this layer.
119
120        **Returns:**
121        - **network_capacity** (`Tensor`): the calculated network capacity result.
122        """
123
124        return self.sum_adjustment_rate / self.num_params

A torchmetrics metric to calculate the network capacity of HAT (Hard Attention to the Task) algorithm.

Network capacity is defined as the average adjustment rate over all parameters. See chapter 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9).

HATNetworkCapacityMetric(nan_strategy: str | float = 'error', **kwargs: Any)
77    def __init__(self, nan_strategy: str | float = "error", **kwargs: Any) -> None:
78        r"""Initialise the HAT network capacity metric. Add state variables."""
79        BaseAggregator.__init__(
80            self,
81            "sum",
82            torch.tensor(0.0, dtype=torch.get_default_dtype()),
83            nan_strategy,
84            state_name="sum_adjustment_rate",
85            **kwargs,
86        )
87        self.sum_adjustment_rate: Tensor
88        r"""State variable created by `add_state()` to store the sum of the adjustment rate values till this layer."""
89
90        self.add_state("num_params", default=torch.tensor(0), dist_reduce_fx="sum")
91        self.num_params: Tensor
92        r"""State variable created by `add_state()` to store the number of the parameters till this layer."""

Initialise the HAT network capacity metric. Add state variables.

sum_adjustment_rate: torch.Tensor

State variable created by add_state() to store the sum of the adjustment rate values till this layer.

num_params: torch.Tensor

State variable created by add_state() to store the number of the parameters till this layer.

def update( self, adjustment_rate_weight_layer: torch.Tensor, adjustment_rate_bias_layer: torch.Tensor) -> None:
 94    def update(
 95        self, adjustment_rate_weight_layer: Tensor, adjustment_rate_bias_layer: Tensor
 96    ) -> None:
 97        r"""Update and accumulate the sum of adjustment rate values till this layer from the layer.
 98
 99        **Args:**
100        - **adjustment_rate_weight_layer** (`Tensor`): the adjustment rate values of the weight matrix of the layer.
101        - **adjustment_rate_bias_layer** (`Tensor`): the adjustment rate values of the bias vector of the layer.
102        """
103        adjustment_rate_weight_layer = torch.as_tensor(
104            adjustment_rate_weight_layer, dtype=self.dtype, device=self.device
105        )
106        adjustment_rate_bias_layer = torch.as_tensor(
107            adjustment_rate_bias_layer, dtype=self.dtype, device=self.device
108        )
109
110        self.sum_adjustment_rate += (
111            adjustment_rate_weight_layer.sum() + adjustment_rate_bias_layer.sum()
112        )
113        self.num_params += (
114            adjustment_rate_weight_layer.numel() + adjustment_rate_bias_layer.numel()
115        )

Update and accumulate the sum of adjustment rate values till this layer from the layer.

Args:

  • adjustment_rate_weight_layer (Tensor): the adjustment rate values of the weight matrix of the layer.
  • adjustment_rate_bias_layer (Tensor): the adjustment rate values of the bias vector of the layer.
def compute(self) -> torch.Tensor:
117    def compute(self) -> Tensor:
118        r"""Compute this HAT network capacity till this layer.
119
120        **Returns:**
121        - **network_capacity** (`Tensor`): the calculated network capacity result.
122        """
123
124        return self.sum_adjustment_rate / self.num_params

Compute this HAT network capacity till this layer.

Returns:

  • network_capacity (Tensor): the calculated network capacity result.