clarena.metrics.hat_network_capacity

The submodule in metrics for HATNetworkCapacity.

 1r"""
 2The submodule in `metrics` for `HATNetworkCapacity`.
 3"""
 4
 5__all__ = ["HATNetworkCapacity"]
 6
 7import logging
 8from typing import Any
 9
10from lightning import Trainer
11from lightning.pytorch.utilities import rank_zero_only
12
13from clarena.cl_algorithms import HAT
14from clarena.metrics import MetricCallback
15
16# always get logger for built-in logging in each module
17pylogger = logging.getLogger(__name__)
18
19
class HATNetworkCapacity(MetricCallback):
    r"""Provides all actions that are related to network capacity of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:

    - Logging network capacity during training. See the "Evaluation Metrics" section in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) for more details about network capacity.

    """

    def __init__(
        self,
        save_dir: str,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
        """
        super().__init__(save_dir=save_dir)

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.

        **Raises:**
        - **TypeError**: when the `pl_module` is not `HAT`.
        """

        # sanity check FIRST: reading HAT-specific attributes off a non-HAT
        # module would raise a confusing AttributeError instead of the
        # documented TypeError
        if not isinstance(pl_module, HAT):
            raise TypeError(
                "The `CLAlgorithm` should be `HAT` to apply `HATNetworkCapacity`!"
            )

        # get the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: HAT,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Log network capacity after a training batch.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`. Must contain the key `"capacity"`.
        - **batch** (`Any`): the training data batch (unused here).
        - **batch_idx** (`int`): the index of the current batch (unused here).
        """

        # get the current network capacity computed by the training step
        capacity = outputs["capacity"]

        # log the network capacity to Lightning loggers (also shown in the progress bar)
        pl_module.log(
            f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True
        )
class HATNetworkCapacity(MetricCallback):
    r"""Metric callback for the network capacity of the [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions.

    Logs the network capacity once per training batch. See the "Evaluation Metrics" section in Sec. 4.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) for the definition of network capacity.
    """

    def __init__(
        self,
        save_dir: str,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
        """
        super().__init__(save_dir=save_dir)

        # counter tracking which task is currently being processed
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
        r"""Record the current task ID at the start of a task's fit (training and validation) and verify that `pl_module` is a `HAT` instance.

        **Raises:**
        - **TypeError**: when the `pl_module` is not `HAT`.
        """

        # read the task counter from the continual-learning algorithm object
        self.task_id = pl_module.task_id

        # accept HAT (and subclasses) only; anything else is a configuration error
        if isinstance(pl_module, HAT):
            return
        raise TypeError(
            "The `CLAlgorithm` should be `HAT` to apply `HATMetricCallback`!"
        )

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: HAT,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Forward the network capacity reported by the training step to the Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the returns of `training_step()` in `HAT`; read for its `"capacity"` entry.
        - **batch** (`Any`): the training data batch (unused here).
        - **batch_idx** (`int`): the index of the current batch (unused here).
        """

        # log the capacity value straight from the step outputs, per-task metric name
        pl_module.log(
            f"task_{self.task_id}/train/network_capacity",
            outputs["capacity"],
            prog_bar=True,
        )

Provides all actions that are related to network capacity of HAT (Hard Attention to the Task) algorithm and its extensions, which include:

  • Logging network capacity during training. See the "Evaluation Metrics" section in Sec. 4.1 in the AdaHAT paper for more details about network capacity.
HATNetworkCapacity(save_dir: str)
28    def __init__(
29        self,
30        save_dir: str,
31    ) -> None:
32        r"""
33        **Args:**
34        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
35        """
36        super().__init__(save_dir=save_dir)
37
38        # task ID control
39        self.task_id: int
40        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

Args:

  • save_dir (str): the directory to save the mask figures. Better inside the output folder.
task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to cl_dataset.num_tasks.

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT) -> None:
42    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
43        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
44
45        **Raises:**
46        -**TypeError**: when the `pl_module` is not `HAT`.
47        """
48
49        # get the current task_id from the `CLAlgorithm` object
50        self.task_id = pl_module.task_id
51
52        # sanity check
53        if not isinstance(pl_module, HAT):
54            raise TypeError(
55                "The `CLAlgorithm` should be `HAT` to apply `HATMetricCallback`!"
56            )

Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the pl_module to be HAT.

Raises: - TypeError: when the pl_module is not HAT.

@rank_zero_only
def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
58    @rank_zero_only
59    def on_train_batch_end(
60        self,
61        trainer: Trainer,
62        pl_module: HAT,
63        outputs: dict[str, Any],
64        batch: Any,
65        batch_idx: int,
66    ) -> None:
67        r"""Plot training mask, adjustment rate and log network capacity after training batch.
68
69        **Args:**
70        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
71        - **batch** (`Any`): the training data batch.
72        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
73        """
74
75        # get the current network capacity
76        capacity = outputs["capacity"]
77
78        # log the network capacity to Lightning loggers
79        pl_module.log(
80            f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True
81        )

Plot training mask, adjustment rate and log network capacity after training batch.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, which is the returns of the training_step() method in the HAT.
  • batch (Any): the training data batch.
  • batch_idx (int): the index of the current batch. This is for the file name of mask figures.