clarena.callbacks.hat_metrics

The submodule in callbacks for HATMetricsCallback.

  1r"""
  2The submodule in `callbacks` for `HATMetricsCallback`.
  3"""
  4
  5__all__ = ["HATMetricsCallback"]
  6
  7import logging
  8import os
  9from typing import Any
 10
 11from lightning import Callback, Trainer
 12
 13from clarena.cl_algorithms import HAT, CLAlgorithm
 14from clarena.utils import plot
 15
 16# always get logger for built-in logging in each module
 17pylogger = logging.getLogger(__name__)
 18
 19
 20class HATMetricsCallback(Callback):
 21    r"""Provides all actions that are related to metrics used for [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm, which include:
 22
 23    - Visualising mask and cumulative mask figures during training and testing as figures.
 24    - Logging network capacity during training. See the "Evaluation Metrics" section in chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) for more details about network capacity.
 25
 26    Lightning provides `self.log()` to log metrics in `LightningModule` where our `CLAlgorithm` based. You can put `self.log()` here if you don't want to mess up the `CLAlgorithm` with a huge amount of logging codes.
 27
 28    The callback is able to produce the following outputs:
 29
 30    - **Mask Figures**: both training and test, masks and cumulative masks.
 31
 32
 33    """
 34
 35    def __init__(
 36        self,
 37        test_masks_plot_dir: str | None,
 38        test_cumulative_masks_plot_dir: str | None,
 39        training_masks_plot_dir: str | None,
 40        plot_training_mask_every_n_steps: int | None,
 41    ) -> None:
 42        r"""Initialise the `HATMetricsCallback`.
 43
 44        **Args:**
 45        - **test_masks_plot_dir** (`str` | `None`): the directory to save the test mask figures. If `None`, no file will be saved.
 46        - **test_cumulative_masks_plot_dir** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
 47        - **training_masks_plot_dir** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
 48        - **plot_training_mask_every_n_steps** (`int`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_plot_dir` is not `None`.
 49        """
 50        Callback.__init__(self)
 51
 52        # paths
 53        if not os.path.exists(test_masks_plot_dir):
 54            os.makedirs(test_masks_plot_dir, exist_ok=True)
 55        self.test_masks_plot_dir: str = test_masks_plot_dir
 56        r"""Store the directory to save the test mask figures."""
 57        if not os.path.exists(test_cumulative_masks_plot_dir):
 58            os.makedirs(test_cumulative_masks_plot_dir, exist_ok=True)
 59        self.test_cumulative_masks_plot_dir: str = test_cumulative_masks_plot_dir
 60        r"""Store the directory to save the test cumulative mask figures."""
 61        if not os.path.exists(training_masks_plot_dir):
 62            os.makedirs(training_masks_plot_dir, exist_ok=True)
 63        self.training_masks_plot_dir: str = training_masks_plot_dir
 64        r"""Store the directory to save the training mask figures."""
 65
 66        # other settings
 67        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
 68        r"""Store the frequency of plotting training masks in terms of number of batches."""
 69
 70        self.task_id: int
 71        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""
 72
 73    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 74        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT` or `AdaHAT`.
 75
 76        **Raises:**
 77        -**TypeError**: when the `pl_module` is not `HAT` or `AdaHAT`.
 78        """
 79
 80        # get the current task_id from the `CLAlgorithm` object
 81        self.task_id = pl_module.task_id
 82
 83        # sanity check
 84        if not isinstance(pl_module, HAT):
 85            raise TypeError(
 86                "The `CLAlgorithm` should be `HAT` or `AdaHAT` to apply `HATMetricsCallback`!"
 87            )
 88
 89    def on_train_batch_end(
 90        self,
 91        trainer: Trainer,
 92        pl_module: CLAlgorithm,
 93        outputs: dict[str, Any],
 94        batch: Any,
 95        batch_idx: int,
 96    ) -> None:
 97        r"""Plot training mask and log network capacity after training batch.
 98
 99        **Args:**
100        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `CLAlgorithm`.
101        - **batch** (`Any`): the training data batch.
102        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
103        """
104
105        # get the mask over the model after training the batch
106        mask = outputs["mask"]
107        # get the current network capacity
108        capacity = outputs["capacity"]
109
110        # plot the mask
111        if batch_idx % self.plot_training_mask_every_n_steps == 0:
112            plot.plot_hat_mask(
113                mask=mask,
114                plot_dir=self.training_masks_plot_dir,
115                task_id=self.task_id,
116                step=batch_idx,
117            )
118
119        # log the network capacity to Lightning loggers
120        pl_module.log(
121            f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True
122        )
123
124    def on_test_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
125        r"""Plot test mask and cumulative mask figures."""
126
127        # test mask
128        mask = pl_module.masks[f"{self.task_id}"]
129        plot.plot_hat_mask(
130            mask=mask, plot_dir=self.test_masks_plot_dir, task_id=self.task_id
131        )
132
133        # cumulative mask
134        cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
135        plot.plot_hat_mask(
136            mask=cumulative_mask,
137            plot_dir=self.test_cumulative_masks_plot_dir,
138            task_id=self.task_id,
139        )
class HATMetricsCallback(lightning.pytorch.callbacks.callback.Callback):
 21class HATMetricsCallback(Callback):
 22    r"""Provides all actions that are related to metrics used for [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm, which include:
 23
 24    - Visualising mask and cumulative mask figures during training and testing as figures.
 25    - Logging network capacity during training. See the "Evaluation Metrics" section in chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) for more details about network capacity.
 26
 27    Lightning provides `self.log()` to log metrics in `LightningModule` where our `CLAlgorithm` based. You can put `self.log()` here if you don't want to mess up the `CLAlgorithm` with a huge amount of logging codes.
 28
 29    The callback is able to produce the following outputs:
 30
 31    - **Mask Figures**: both training and test, masks and cumulative masks.
 32
 33
 34    """
 35
 36    def __init__(
 37        self,
 38        test_masks_plot_dir: str | None,
 39        test_cumulative_masks_plot_dir: str | None,
 40        training_masks_plot_dir: str | None,
 41        plot_training_mask_every_n_steps: int | None,
 42    ) -> None:
 43        r"""Initialise the `HATMetricsCallback`.
 44
 45        **Args:**
 46        - **test_masks_plot_dir** (`str` | `None`): the directory to save the test mask figures. If `None`, no file will be saved.
 47        - **test_cumulative_masks_plot_dir** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
 48        - **training_masks_plot_dir** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
 49        - **plot_training_mask_every_n_steps** (`int`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_plot_dir` is not `None`.
 50        """
 51        Callback.__init__(self)
 52
 53        # paths
 54        if not os.path.exists(test_masks_plot_dir):
 55            os.makedirs(test_masks_plot_dir, exist_ok=True)
 56        self.test_masks_plot_dir: str = test_masks_plot_dir
 57        r"""Store the directory to save the test mask figures."""
 58        if not os.path.exists(test_cumulative_masks_plot_dir):
 59            os.makedirs(test_cumulative_masks_plot_dir, exist_ok=True)
 60        self.test_cumulative_masks_plot_dir: str = test_cumulative_masks_plot_dir
 61        r"""Store the directory to save the test cumulative mask figures."""
 62        if not os.path.exists(training_masks_plot_dir):
 63            os.makedirs(training_masks_plot_dir, exist_ok=True)
 64        self.training_masks_plot_dir: str = training_masks_plot_dir
 65        r"""Store the directory to save the training mask figures."""
 66
 67        # other settings
 68        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
 69        r"""Store the frequency of plotting training masks in terms of number of batches."""
 70
 71        self.task_id: int
 72        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""
 73
 74    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 75        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT` or `AdaHAT`.
 76
 77        **Raises:**
 78        -**TypeError**: when the `pl_module` is not `HAT` or `AdaHAT`.
 79        """
 80
 81        # get the current task_id from the `CLAlgorithm` object
 82        self.task_id = pl_module.task_id
 83
 84        # sanity check
 85        if not isinstance(pl_module, HAT):
 86            raise TypeError(
 87                "The `CLAlgorithm` should be `HAT` or `AdaHAT` to apply `HATMetricsCallback`!"
 88            )
 89
 90    def on_train_batch_end(
 91        self,
 92        trainer: Trainer,
 93        pl_module: CLAlgorithm,
 94        outputs: dict[str, Any],
 95        batch: Any,
 96        batch_idx: int,
 97    ) -> None:
 98        r"""Plot training mask and log network capacity after training batch.
 99
100        **Args:**
101        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `CLAlgorithm`.
102        - **batch** (`Any`): the training data batch.
103        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
104        """
105
106        # get the mask over the model after training the batch
107        mask = outputs["mask"]
108        # get the current network capacity
109        capacity = outputs["capacity"]
110
111        # plot the mask
112        if batch_idx % self.plot_training_mask_every_n_steps == 0:
113            plot.plot_hat_mask(
114                mask=mask,
115                plot_dir=self.training_masks_plot_dir,
116                task_id=self.task_id,
117                step=batch_idx,
118            )
119
120        # log the network capacity to Lightning loggers
121        pl_module.log(
122            f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True
123        )
124
125    def on_test_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
126        r"""Plot test mask and cumulative mask figures."""
127
128        # test mask
129        mask = pl_module.masks[f"{self.task_id}"]
130        plot.plot_hat_mask(
131            mask=mask, plot_dir=self.test_masks_plot_dir, task_id=self.task_id
132        )
133
134        # cumulative mask
135        cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
136        plot.plot_hat_mask(
137            mask=cumulative_mask,
138            plot_dir=self.test_cumulative_masks_plot_dir,
139            task_id=self.task_id,
140        )

Provides all actions that are related to metrics used for HAT (Hard Attention to the Task) algorithm, which include:

  • Visualising mask and cumulative mask figures during training and testing as figures.
  • Logging network capacity during training. See the "Evaluation Metrics" section in chapter 4.1 in AdaHAT paper for more details about network capacity.

Lightning provides self.log() to log metrics in LightningModule, which our CLAlgorithm is based on. You can put self.log() here if you don't want to clutter the CLAlgorithm with a large amount of logging code.

The callback is able to produce the following outputs:

  • Mask Figures: both training and test, masks and cumulative masks.
HATMetricsCallback( test_masks_plot_dir: str | None, test_cumulative_masks_plot_dir: str | None, training_masks_plot_dir: str | None, plot_training_mask_every_n_steps: int | None)
36    def __init__(
37        self,
38        test_masks_plot_dir: str | None,
39        test_cumulative_masks_plot_dir: str | None,
40        training_masks_plot_dir: str | None,
41        plot_training_mask_every_n_steps: int | None,
42    ) -> None:
43        r"""Initialise the `HATMetricsCallback`.
44
45        **Args:**
46        - **test_masks_plot_dir** (`str` | `None`): the directory to save the test mask figures. If `None`, no file will be saved.
47        - **test_cumulative_masks_plot_dir** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
48        - **training_masks_plot_dir** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
49        - **plot_training_mask_every_n_steps** (`int`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_plot_dir` is not `None`.
50        """
51        Callback.__init__(self)
52
53        # paths
54        if not os.path.exists(test_masks_plot_dir):
55            os.makedirs(test_masks_plot_dir, exist_ok=True)
56        self.test_masks_plot_dir: str = test_masks_plot_dir
57        r"""Store the directory to save the test mask figures."""
58        if not os.path.exists(test_cumulative_masks_plot_dir):
59            os.makedirs(test_cumulative_masks_plot_dir, exist_ok=True)
60        self.test_cumulative_masks_plot_dir: str = test_cumulative_masks_plot_dir
61        r"""Store the directory to save the test cumulative mask figures."""
62        if not os.path.exists(training_masks_plot_dir):
63            os.makedirs(training_masks_plot_dir, exist_ok=True)
64        self.training_masks_plot_dir: str = training_masks_plot_dir
65        r"""Store the directory to save the training mask figures."""
66
67        # other settings
68        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
69        r"""Store the frequency of plotting training masks in terms of number of batches."""
70
71        self.task_id: int
72        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""

Initialise the HATMetricsCallback.

Args:

  • test_masks_plot_dir (str | None): the directory to save the test mask figures. If None, no file will be saved.
  • test_cumulative_masks_plot_dir (str | None): the directory to save the test cumulative mask figures. If None, no file will be saved.
  • training_masks_plot_dir (str | None): the directory to save the training mask figures. If None, no file will be saved.
  • plot_training_mask_every_n_steps (int | None): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when training_masks_plot_dir is not None.
test_masks_plot_dir: str

Store the directory to save the test mask figures.

test_cumulative_masks_plot_dir: str

Store the directory to save the test cumulative mask figures.

training_masks_plot_dir: str

Store the directory to save the training mask figures.

plot_training_mask_every_n_steps: int

Store the frequency of plotting training masks in terms of number of batches.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop.

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
74    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
75        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT` or `AdaHAT`.
76
77        **Raises:**
78        -**TypeError**: when the `pl_module` is not `HAT` or `AdaHAT`.
79        """
80
81        # get the current task_id from the `CLAlgorithm` object
82        self.task_id = pl_module.task_id
83
84        # sanity check
85        if not isinstance(pl_module, HAT):
86            raise TypeError(
87                "The `CLAlgorithm` should be `HAT` or `AdaHAT` to apply `HATMetricsCallback`!"
88            )

Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the pl_module to be HAT or AdaHAT.

Raises: TypeError: when the pl_module is not HAT or AdaHAT.

def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
 90    def on_train_batch_end(
 91        self,
 92        trainer: Trainer,
 93        pl_module: CLAlgorithm,
 94        outputs: dict[str, Any],
 95        batch: Any,
 96        batch_idx: int,
 97    ) -> None:
 98        r"""Plot training mask and log network capacity after training batch.
 99
100        **Args:**
101        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `CLAlgorithm`.
102        - **batch** (`Any`): the training data batch.
103        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
104        """
105
106        # get the mask over the model after training the batch
107        mask = outputs["mask"]
108        # get the current network capacity
109        capacity = outputs["capacity"]
110
111        # plot the mask
112        if batch_idx % self.plot_training_mask_every_n_steps == 0:
113            plot.plot_hat_mask(
114                mask=mask,
115                plot_dir=self.training_masks_plot_dir,
116                task_id=self.task_id,
117                step=batch_idx,
118            )
119
120        # log the network capacity to Lightning loggers
121        pl_module.log(
122            f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True
123        )

Plot training mask and log network capacity after training batch.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, which is the returns of the training_step() method in the CLAlgorithm.
  • batch (Any): the training data batch.
  • batch_idx (int): the index of the current batch. This is for the file name of mask figures.
def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
125    def on_test_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
126        r"""Plot test mask and cumulative mask figures."""
127
128        # test mask
129        mask = pl_module.masks[f"{self.task_id}"]
130        plot.plot_hat_mask(
131            mask=mask, plot_dir=self.test_masks_plot_dir, task_id=self.task_id
132        )
133
134        # cumulative mask
135        cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
136        plot.plot_hat_mask(
137            mask=cumulative_mask,
138            plot_dir=self.test_cumulative_masks_plot_dir,
139            task_id=self.task_id,
140        )

Plot test mask and cumulative mask figures.