clarena.metrics.hat_masks

The submodule in metrics for HATMasks.

  1r"""
  2The submodule in `metrics` for `HATMasks`.
  3"""
  4
  5__all__ = ["HATMasks"]
  6
  7import logging
  8import os
  9from typing import Any
 10
 11from lightning import Trainer
 12from lightning.pytorch.utilities import rank_zero_only
 13from matplotlib import pyplot as plt
 14from torch import Tensor
 15
 16from clarena.cl_algorithms import HAT
 17from clarena.metrics import MetricCallback
 18
 19# module-level logger named after this module (`__name__`); NOTE(review): `HATMasks` below does not currently use it
 20pylogger = logging.getLogger(__name__)
 21
 22
 23class HATMasks(MetricCallback):
 24    r"""Provides all actions that are related to masks of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:
 25
 26    - Visualizing mask and cumulative mask figures during training and testing as figures.
 27
 28    The callback is able to produce the following outputs:
 29
 30    - Figures of both training and test, masks and cumulative masks.
 31
 32    """
 33
 34    def __init__(
 35        self,
 36        save_dir: str,
 37        test_masks_dir_name: str | None = None,
 38        test_cumulative_masks_dir_name: str | None = None,
 39        training_masks_dir_name: str | None = None,
 40        plot_training_mask_every_n_steps: int | None = None,
 41    ) -> None:
 42        r"""
 43        **Args:**
 44        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
 45        - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
 46        - **test_cumulative_masks_dir_name** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
 47        - **training_masks_dir_name** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
 48        - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
 49        """
 50        super().__init__(save_dir=save_dir)
 51
 52        # paths
 53        if test_masks_dir_name is not None:
 54            self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
 55            r"""The directory to save the test mask figures."""
 56        if test_cumulative_masks_dir_name is not None:
 57            self.test_cumulative_masks_dir: str = os.path.join(
 58                self.save_dir, test_cumulative_masks_dir_name
 59            )
 60            r"""The directory to save the test cumulative mask figures."""
 61        if training_masks_dir_name is not None:
 62            self.training_masks_dir: str = os.path.join(
 63                self.save_dir, training_masks_dir_name
 64            )
 65            r"""The directory to save the training mask figures."""
 66
 67        # other settings
 68        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
 69        r"""The frequency of plotting training masks in terms of number of batches."""
 70
 71        # task ID control
 72        self.task_id: int
 73        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 74
 75    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
 76        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
 77
 78        **Raises:**
 79        -**TypeError**: when the `pl_module` is not `HAT`.
 80        """
 81
 82        # get the current task_id from the `CLAlgorithm` object
 83        self.task_id = pl_module.task_id
 84
 85        # sanity check
 86        if not isinstance(pl_module, HAT):
 87            raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")
 88
 89    @rank_zero_only
 90    def on_train_batch_end(
 91        self,
 92        trainer: Trainer,
 93        pl_module: HAT,
 94        outputs: dict[str, Any],
 95        batch: Any,
 96        batch_idx: int,
 97    ) -> None:
 98        r"""Plot training mask after training batch.
 99
100        **Args:**
101        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
102        - **batch** (`Any`): the training data batch.
103        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
104        """
105
106        # get the mask over the model after training the batch
107        mask = outputs["mask"]
108
109        # plot the mask
110        if hasattr(self, "training_masks_dir"):
111            if batch_idx % self.plot_training_mask_every_n_steps == 0:
112                self.plot_hat_mask(
113                    mask=mask,
114                    plot_dir=self.training_masks_dir,
115                    task_id=self.task_id,
116                    step=batch_idx,
117                )
118
119    @rank_zero_only
120    def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
121        r"""Plot test mask and cumulative mask figures."""
122
123        # test mask
124        if hasattr(self, "test_masks_dir"):
125            mask = pl_module.masks[self.task_id]
126            self.plot_hat_mask(
127                mask=mask, plot_dir=self.test_masks_dir, task_id=self.task_id
128            )
129
130        # cumulative mask
131        if hasattr(self, "test_cumulative_masks_dir"):
132            cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
133            self.plot_hat_mask(
134                mask=cumulative_mask,
135                plot_dir=self.test_cumulative_masks_dir,
136                task_id=self.task_id,
137            )
138
139    def plot_hat_mask(
140        self,
141        mask: dict[str, Tensor],
142        plot_dir: str,
143        task_id: int,
144        step: int | None = None,
145    ) -> None:
146        """Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm. This includes the mask and cumulative mask.
147
148        **Args:**
149        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
150        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
151        - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
152        - **step** (`int`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
153        """
154
155        for layer_name, m in mask.items():
156            layer_name = layer_name.replace(
157                "/", "."
158            )  # the layer name contains '/', which is not allowed in the file name. We replace it back with '.'.
159
160            m = m.view(
161                1, -1
162            )  # reshape the 1D mask to 2D so can be plotted by image show
163
164            fig = plt.figure()
165            plt.imshow(
166                m.detach().cpu(), aspect="auto", cmap="Greys"
167            )  # can only convert to tensors in CPU to numpy arrays
168            plt.yticks()  # hide yticks
169            plt.colorbar()
170            if step:
171                plot_name = f"{layer_name}_task{task_id}_step{step}.png"
172            else:
173                plot_name = f"{layer_name}_task{task_id}.png"
174            plot_path = os.path.join(plot_dir, plot_name)
175            fig.savefig(plot_path)
176            plt.close(fig)
177            plt.close(fig)
178            plt.close(fig)
class HATMasks(clarena.metrics.base.MetricCallback):
 24class HATMasks(MetricCallback):
 25    r"""Provides all actions that are related to masks of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:
 26
 27    - Visualizing mask and cumulative mask figures during training and testing as figures.
 28
 29    The callback is able to produce the following outputs:
 30
 31    - Figures of both training and test, masks and cumulative masks.
 32
 33    """
 34
 35    def __init__(
 36        self,
 37        save_dir: str,
 38        test_masks_dir_name: str | None = None,
 39        test_cumulative_masks_dir_name: str | None = None,
 40        training_masks_dir_name: str | None = None,
 41        plot_training_mask_every_n_steps: int | None = None,
 42    ) -> None:
 43        r"""
 44        **Args:**
 45        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
 46        - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
 47        - **test_cumulative_masks_dir_name** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
 48        - **training_masks_dir_name** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
 49        - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
 50        """
 51        super().__init__(save_dir=save_dir)
 52
 53        # paths
 54        if test_masks_dir_name is not None:
 55            self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
 56            r"""The directory to save the test mask figures."""
 57        if test_cumulative_masks_dir_name is not None:
 58            self.test_cumulative_masks_dir: str = os.path.join(
 59                self.save_dir, test_cumulative_masks_dir_name
 60            )
 61            r"""The directory to save the test cumulative mask figures."""
 62        if training_masks_dir_name is not None:
 63            self.training_masks_dir: str = os.path.join(
 64                self.save_dir, training_masks_dir_name
 65            )
 66            r"""The directory to save the training mask figures."""
 67
 68        # other settings
 69        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
 70        r"""The frequency of plotting training masks in terms of number of batches."""
 71
 72        # task ID control
 73        self.task_id: int
 74        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 75
 76    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
 77        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
 78
 79        **Raises:**
 80        -**TypeError**: when the `pl_module` is not `HAT`.
 81        """
 82
 83        # get the current task_id from the `CLAlgorithm` object
 84        self.task_id = pl_module.task_id
 85
 86        # sanity check
 87        if not isinstance(pl_module, HAT):
 88            raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")
 89
 90    @rank_zero_only
 91    def on_train_batch_end(
 92        self,
 93        trainer: Trainer,
 94        pl_module: HAT,
 95        outputs: dict[str, Any],
 96        batch: Any,
 97        batch_idx: int,
 98    ) -> None:
 99        r"""Plot training mask after training batch.
100
101        **Args:**
102        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
103        - **batch** (`Any`): the training data batch.
104        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
105        """
106
107        # get the mask over the model after training the batch
108        mask = outputs["mask"]
109
110        # plot the mask
111        if hasattr(self, "training_masks_dir"):
112            if batch_idx % self.plot_training_mask_every_n_steps == 0:
113                self.plot_hat_mask(
114                    mask=mask,
115                    plot_dir=self.training_masks_dir,
116                    task_id=self.task_id,
117                    step=batch_idx,
118                )
119
120    @rank_zero_only
121    def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
122        r"""Plot test mask and cumulative mask figures."""
123
124        # test mask
125        if hasattr(self, "test_masks_dir"):
126            mask = pl_module.masks[self.task_id]
127            self.plot_hat_mask(
128                mask=mask, plot_dir=self.test_masks_dir, task_id=self.task_id
129            )
130
131        # cumulative mask
132        if hasattr(self, "test_cumulative_masks_dir"):
133            cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
134            self.plot_hat_mask(
135                mask=cumulative_mask,
136                plot_dir=self.test_cumulative_masks_dir,
137                task_id=self.task_id,
138            )
139
140    def plot_hat_mask(
141        self,
142        mask: dict[str, Tensor],
143        plot_dir: str,
144        task_id: int,
145        step: int | None = None,
146    ) -> None:
147        """Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm. This includes the mask and cumulative mask.
148
149        **Args:**
150        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
151        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
152        - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
153        - **step** (`int`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
154        """
155
156        for layer_name, m in mask.items():
157            layer_name = layer_name.replace(
158                "/", "."
159            )  # the layer name contains '/', which is not allowed in the file name. We replace it back with '.'.
160
161            m = m.view(
162                1, -1
163            )  # reshape the 1D mask to 2D so can be plotted by image show
164
165            fig = plt.figure()
166            plt.imshow(
167                m.detach().cpu(), aspect="auto", cmap="Greys"
168            )  # can only convert to tensors in CPU to numpy arrays
169            plt.yticks()  # hide yticks
170            plt.colorbar()
171            if step:
172                plot_name = f"{layer_name}_task{task_id}_step{step}.png"
173            else:
174                plot_name = f"{layer_name}_task{task_id}.png"
175            plot_path = os.path.join(plot_dir, plot_name)
176            fig.savefig(plot_path)
177            plt.close(fig)
178            plt.close(fig)
179            plt.close(fig)

Provides all actions that are related to masks of HAT (Hard Attention to the Task) algorithm and its extensions, which include:

  • Visualizing masks and cumulative masks as figures during training and testing.

The callback is able to produce the following outputs:

  • Figures of masks and cumulative masks, for both training and testing.
HATMasks( save_dir: str, test_masks_dir_name: str | None = None, test_cumulative_masks_dir_name: str | None = None, training_masks_dir_name: str | None = None, plot_training_mask_every_n_steps: int | None = None)
35    def __init__(
36        self,
37        save_dir: str,
38        test_masks_dir_name: str | None = None,
39        test_cumulative_masks_dir_name: str | None = None,
40        training_masks_dir_name: str | None = None,
41        plot_training_mask_every_n_steps: int | None = None,
42    ) -> None:
43        r"""
44        **Args:**
45        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
46        - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
47        - **test_cumulative_masks_dir_name** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
48        - **training_masks_dir_name** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
49        - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
50        """
51        super().__init__(save_dir=save_dir)
52
53        # paths
54        if test_masks_dir_name is not None:
55            self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
56            r"""The directory to save the test mask figures."""
57        if test_cumulative_masks_dir_name is not None:
58            self.test_cumulative_masks_dir: str = os.path.join(
59                self.save_dir, test_cumulative_masks_dir_name
60            )
61            r"""The directory to save the test cumulative mask figures."""
62        if training_masks_dir_name is not None:
63            self.training_masks_dir: str = os.path.join(
64                self.save_dir, training_masks_dir_name
65            )
66            r"""The directory to save the training mask figures."""
67
68        # other settings
69        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
70        r"""The frequency of plotting training masks in terms of number of batches."""
71
72        # task ID control
73        self.task_id: int
74        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

Args:

  • save_dir (str): the directory to save the mask figures. Better inside the output folder.
  • test_masks_dir_name (str | None): the relative path to save_dir to save the test mask figures. If None, no file will be saved.
  • test_cumulative_masks_dir_name (str | None): the path, relative to save_dir, of the directory to save the test cumulative mask figures. If None, no file will be saved.
  • training_masks_dir_name (str | None): the path, relative to save_dir, of the directory to save the training mask figures. If None, no file will be saved.
  • plot_training_mask_every_n_steps (int | None): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when training_masks_dir_name is not None.
plot_training_mask_every_n_steps: int

The frequency of plotting training masks in terms of number of batches.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to cl_dataset.num_tasks.

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT) -> None:
76    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
77        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
78
79        **Raises:**
80        -**TypeError**: when the `pl_module` is not `HAT`.
81        """
82
83        # get the current task_id from the `CLAlgorithm` object
84        self.task_id = pl_module.task_id
85
86        # sanity check
87        if not isinstance(pl_module, HAT):
88            raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")

Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the pl_module to be HAT.

Raises: TypeError: when the pl_module is not HAT.

@rank_zero_only
def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
 90    @rank_zero_only
 91    def on_train_batch_end(
 92        self,
 93        trainer: Trainer,
 94        pl_module: HAT,
 95        outputs: dict[str, Any],
 96        batch: Any,
 97        batch_idx: int,
 98    ) -> None:
 99        r"""Plot training mask after training batch.
100
101        **Args:**
102        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
103        - **batch** (`Any`): the training data batch.
104        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
105        """
106
107        # get the mask over the model after training the batch
108        mask = outputs["mask"]
109
110        # plot the mask
111        if hasattr(self, "training_masks_dir"):
112            if batch_idx % self.plot_training_mask_every_n_steps == 0:
113                self.plot_hat_mask(
114                    mask=mask,
115                    plot_dir=self.training_masks_dir,
116                    task_id=self.task_id,
117                    step=batch_idx,
118                )

Plot training mask after training batch.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, i.e. the return value of HAT's training_step() method.
  • batch (Any): the training data batch.
  • batch_idx (int): the index of the current batch. This is for the file name of mask figures.
@rank_zero_only
def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT) -> None:
120    @rank_zero_only
121    def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
122        r"""Plot test mask and cumulative mask figures."""
123
124        # test mask
125        if hasattr(self, "test_masks_dir"):
126            mask = pl_module.masks[self.task_id]
127            self.plot_hat_mask(
128                mask=mask, plot_dir=self.test_masks_dir, task_id=self.task_id
129            )
130
131        # cumulative mask
132        if hasattr(self, "test_cumulative_masks_dir"):
133            cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
134            self.plot_hat_mask(
135                mask=cumulative_mask,
136                plot_dir=self.test_cumulative_masks_dir,
137                task_id=self.task_id,
138            )

Plot test mask and cumulative mask figures.

def plot_hat_mask( self, mask: dict[str, torch.Tensor], plot_dir: str, task_id: int, step: int | None = None) -> None:
140    def plot_hat_mask(
141        self,
142        mask: dict[str, Tensor],
143        plot_dir: str,
144        task_id: int,
145        step: int | None = None,
146    ) -> None:
147        """Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm. This includes the mask and cumulative mask.
148
149        **Args:**
150        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
151        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
152        - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
153        - **step** (`int`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
154        """
155
156        for layer_name, m in mask.items():
157            layer_name = layer_name.replace(
158                "/", "."
159            )  # the layer name contains '/', which is not allowed in the file name. We replace it back with '.'.
160
161            m = m.view(
162                1, -1
163            )  # reshape the 1D mask to 2D so can be plotted by image show
164
165            fig = plt.figure()
166            plt.imshow(
167                m.detach().cpu(), aspect="auto", cmap="Greys"
168            )  # can only convert to tensors in CPU to numpy arrays
169            plt.yticks()  # hide yticks
170            plt.colorbar()
171            if step:
172                plot_name = f"{layer_name}_task{task_id}_step{step}.png"
173            else:
174                plot_name = f"{layer_name}_task{task_id}.png"
175            plot_path = os.path.join(plot_dir, plot_name)
176            fig.savefig(plot_path)
177            plt.close(fig)
178            plt.close(fig)
179            plt.close(fig)

Plot masks of the HAT (Hard Attention to the Task) algorithm. This covers both the mask and the cumulative mask.

Args:

  • mask (dict[str, Tensor]): the hard attention (whose values are 0 or 1) mask. Keys (str) are layer name and values (Tensor) are the mask tensors. The mask tensor has size (number of units, ).
  • plot_dir (str): the directory to save plot. Better same as the output directory of the experiment.
  • task_id (int): the task ID of the mask to be plotted. This is to form the plot name.
  • step (int | None): the training step (batch index) of the mask to be plotted. Applies to the training mask only. This is used to form the plot name. Leave as None to omit the step from the plot name.