clarena.metrics.hat_adjustment_rate

The submodule in metrics for HATAdjustmentRate.

  1r"""
  2The submodule in `metrics` for `HATAdjustmentRate`.
  3"""
  4
  5__all__ = ["HATAdjustmentRate"]
  6
  7import logging
  8import os
  9from typing import Any
 10
 11from lightning import Trainer
 12from lightning.pytorch.utilities import rank_zero_only
 13from matplotlib import pyplot as plt
 14from matplotlib.colors import LogNorm
 15from torch import Tensor
 16
 17from clarena.cl_algorithms import HAT
 18from clarena.metrics import MetricCallback
 19
 20# always get logger for built-in logging in each module
 21pylogger = logging.getLogger(__name__)
 22
 23
 24class HATAdjustmentRate(MetricCallback):
 25    r"""Provides all actions that are related to adjustment rate of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:
 26
 27    - Visualizing adjustment rate during training as figures.
 28
 29    The callback is able to produce the following outputs:
 30
 31    - Figures of training adjustment rate.
 32
 33    """
 34
 35    def __init__(
 36        self,
 37        save_dir: str,
 38        plot_adjustment_rate_every_n_steps: int | None = None,
 39    ) -> None:
 40        r"""
 41        **Args:**
 42        - **save_dir** (`str` | `None`): The directory to save the adjustment rate figures. Better inside the output folder.
 43        - **plot_adjustment_rate_every_n_steps** (`int` | `None`): the frequency of plotting adjustment rate figures in terms of number of batches during training.
 44        """
 45        super().__init__(save_dir=save_dir)
 46
 47        # other settings
 48        self.plot_adjustment_rate_every_n_steps: int = (
 49            plot_adjustment_rate_every_n_steps
 50        )
 51        r"""The frequency of plotting adjustment rate in terms of number of batches."""
 52
 53        # task ID control
 54        self.task_id: int
 55        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 56
 57    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
 58        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
 59
 60        **Raises:**
 61        -**TypeError**: when the `pl_module` is not `HAT`.
 62        """
 63
 64        # get the current task_id from the `CLAlgorithm` object
 65        self.task_id = pl_module.task_id
 66
 67        # sanity check
 68        if not isinstance(pl_module, HAT):
 69            raise TypeError(
 70                "The `CLAlgorithm` should be `HAT` to apply `HATAdjustmentRate`!"
 71            )
 72
 73    @rank_zero_only
 74    def on_train_batch_end(
 75        self,
 76        trainer: Trainer,
 77        pl_module: HAT,
 78        outputs: dict[str, Any],
 79        batch: Any,
 80        batch_idx: int,
 81    ) -> None:
 82        r"""Plot adjustment rate after training batch.
 83
 84        **Args:**
 85        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
 86        - **batch** (`Any`): the training data batch.
 87        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
 88        """
 89
 90        # get the adjustment rate
 91        adjustment_rate_weight = outputs["adjustment_rate_weight"]
 92        adjustment_rate_bias = outputs["adjustment_rate_bias"]
 93
 94        # plot the adjustment rate
 95        if self.save_dir is not None:
 96            if self.task_id > 1:
 97                if batch_idx % self.plot_adjustment_rate_every_n_steps == 0:
 98                    self.plot_hat_adjustment_rate(
 99                        adjustment_rate=adjustment_rate_weight,
100                        weight_or_bias="weight",
101                        step=batch_idx,
102                    )
103                    self.plot_hat_adjustment_rate(
104                        adjustment_rate=adjustment_rate_bias,
105                        weight_or_bias="bias",
106                        step=batch_idx,
107                    )
108
109    def plot_hat_adjustment_rate(
110        self,
111        adjustment_rate: dict[str, Tensor],
112        weight_or_bias: str,
113        step: int | None = None,
114    ) -> None:
115        """Plot adjustment rate in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm. This includes the adjustment rate weight and adjustment rate bias (if applicable).
116
117        **Args:**
118        - **adjustment_rate** (`dict[str, Tensor]`): the adjustment rate. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. If it's adjustment rate weight, it has size same as weights. If it's adjustment rate bias, it has size same as biases.
119        - **weight_or_bias** (`str`): the type of adjustment rate. It can be either 'weight' or 'bias'. This is to form the plot name.
120        - **step** (`int`): the training step (batch index) of the adjustment rate to be plotted. This is to form the plot name. Keep `None` for not showing the step in the plot name.
121        """
122
123        for layer_name, a in adjustment_rate.items():
124            layer_name = layer_name.replace(
125                "/", "."
126            )  # the layer name contains '/', which is not allowed in the file name. We replace it back with '.'.
127
128            if weight_or_bias == "bias":
129                a = a.view(
130                    1, -1
131                )  # reshape the 1D mask to 2D so can be plotted by image show
132
133            fig = plt.figure()
134            plt.imshow(
135                a.detach().cpu(),
136                aspect="auto",
137                cmap="Wistia",
138                norm=LogNorm(vmin=1e-7, vmax=1e-5),
139            )  # can only convert to tensors in CPU to numpy arrays
140            plt.yticks()  # hide yticks
141            plt.colorbar()
142            if step:
143                plot_name = (
144                    f"{layer_name}_{weight_or_bias}_task{self.task_id}_step{step}.png"
145                )
146            else:
147                plot_name = f"{layer_name}_{weight_or_bias}_task{self.task_id}.png"
148            plot_path = os.path.join(self.save_dir, plot_name)
149            fig.savefig(plot_path)
150            plt.close(fig)
class HATAdjustmentRate(clarena.metrics.base.MetricCallback):

Provides all actions related to the adjustment rate of the HAT (Hard Attention to the Task) algorithm and its extensions, which include:

  • Visualizing the adjustment rate during training as figures.

The callback produces the following outputs:

  • Figures of the training adjustment rate.
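
As a minimal usage sketch, the callback is attached to a Lightning Trainer like any other callback. The save_dir path and plotting frequency below are illustrative, and hat_module and datamodule stand in for a HAT module and data module configured elsewhere:

from lightning import Trainer

from clarena.metrics import HATAdjustmentRate

# construct the callback; figures are saved under `save_dir`
adjustment_rate_callback = HATAdjustmentRate(
    save_dir="outputs/hat_adjustment_rate",
    plot_adjustment_rate_every_n_steps=100,  # plot every 100 training batches
)

# attach it to the Trainer like any other Lightning callback
trainer = Trainer(callbacks=[adjustment_rate_callback])
# hypothetical: fit a configured HAT module on a configured datamodule
# trainer.fit(hat_module, datamodule)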
HATAdjustmentRate(save_dir: str, plot_adjustment_rate_every_n_steps: int | None = None)

Args:

  • save_dir (str): the directory to save the adjustment rate figures. Preferably inside the output folder.
  • plot_adjustment_rate_every_n_steps (int | None): the frequency of plotting adjustment rate figures, in number of training batches. Keep None to disable plotting during training.
plot_adjustment_rate_every_n_steps: int | None

The frequency of plotting the adjustment rate, in number of training batches.

task_id: int

Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to cl_dataset.num_tasks.

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT) -> None:

Get the current task ID at the beginning of a task's fitting (training and validation). Sanity-check that pl_module is HAT.

Raises:

  • TypeError: when the pl_module is not HAT.

@rank_zero_only
def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:

Plot the adjustment rate after a training batch.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, i.e. the return value of the training_step() method in HAT.
  • batch (Any): the training data batch.
  • batch_idx (int): the index of the current batch. This is used in the file names of the figures.
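
As a sketch of the contract this hook assumes, training_step() in HAT is expected to return a dict containing per-layer adjustment rates keyed by layer name. The layer name "fc1", the shapes, and the magnitudes below are hypothetical:

import torch

# hypothetical structure of the `outputs` dict consumed by this hook
outputs = {
    "adjustment_rate_weight": {
        # one tensor per layer, with the same shape as that layer's weight
        "fc1": torch.rand(128, 784) * 1e-5 + 1e-7,
    },
    "adjustment_rate_bias": {
        # one tensor per layer, with the same shape as that layer's bias
        "fc1": torch.rand(128) * 1e-5 + 1e-7,
    },
    # ... plus whatever else `training_step()` returns, such as the loss
}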
def plot_hat_adjustment_rate( self, adjustment_rate: dict[str, torch.Tensor], weight_or_bias: str, step: int | None = None) -> None:

Plot the adjustment rate in the HAT (Hard Attention to the Task) algorithm. This includes the adjustment rate of weights and of biases (if applicable).

Args:

  • adjustment_rate (dict[str, Tensor]): the adjustment rate. Keys (str) are layer names and values (Tensor) are the adjustment rate tensors. An adjustment rate weight has the same size as the weights; an adjustment rate bias has the same size as the biases.
  • weight_or_bias (str): the type of adjustment rate, either 'weight' or 'bias'. This forms part of the plot file name.
  • step (int | None): the training step (batch index) of the adjustment rate to be plotted. This forms part of the plot file name. Keep None to omit the step from the file name.
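
For illustration, the method can also be called directly outside the training loop. A minimal sketch with synthetic adjustment rates; the layer name backbone/fc1, the shapes, and the magnitudes are made up:

import os

import torch

from clarena.metrics import HATAdjustmentRate

save_dir = "outputs/hat_adjustment_rate"
os.makedirs(save_dir, exist_ok=True)  # make sure the figure directory exists

callback = HATAdjustmentRate(save_dir=save_dir)
callback.task_id = 2  # normally set by `on_fit_start()` during fitting

# synthetic adjustment rates for a single hypothetical layer; values are kept
# within the LogNorm range [1e-7, 1e-5] used by the plot
callback.plot_hat_adjustment_rate(
    adjustment_rate={"backbone/fc1": torch.rand(128, 784) * 1e-5 + 1e-7},
    weight_or_bias="weight",
    step=0,
)
# saves a figure named `backbone.fc1_weight_task2_step0.png` under `save_dir`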