clarena.metrics.hat_masks

The submodule in metrics for HATMasks.

  1r"""
  2The submodule in `metrics` for `HATMasks`.
  3"""
  4
  5__all__ = ["HATMasks"]
  6
  7import logging
  8import os
  9from typing import Any
 10
 11from lightning import Trainer
 12from lightning.pytorch.utilities import rank_zero_only
 13from matplotlib import pyplot as plt
 14from torch import Tensor
 15
 16from clarena.cl_algorithms import HAT
 17from clarena.metrics import MetricCallback
 18
 19# always get logger for built-in logging in each module
 20pylogger = logging.getLogger(__name__)
 21
 22
 23class HATMasks(MetricCallback):
 24    r"""Provides all actions that are related to masks of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:
 25
 26    - Visualizing mask and cumulative mask figures during training and testing as figures.
 27
 28    The callback is able to produce the following outputs:
 29
 30    - Figures of both training and test, masks and cumulative masks.
 31
 32    """
 33
 34    def __init__(
 35        self,
 36        save_dir: str,
 37        test_masks_dir_name: str | None = None,
 38        test_cumulative_masks_dir_name: str | None = None,
 39        test_masks_grouped_dir_name: str | None = None,
 40        training_masks_dir_name: str | None = None,
 41        plot_training_mask_every_n_steps: int | None = None,
 42    ) -> None:
 43        r"""
 44        **Args:**
 45        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
 46        - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
 47        - **test_cumulative_masks_dir_name** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
 48        - **test_masks_grouped_dir_name** (`str` | `None`): the directory to save the test mask grouped figures. If `None`, no file will be saved.
 49        - **training_masks_dir_name** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
 50        - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
 51        """
 52        super().__init__(save_dir=save_dir)
 53
 54        # paths
 55        if test_masks_dir_name is not None:
 56            self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
 57            r"""The directory to save the test mask figures."""
 58        if test_cumulative_masks_dir_name is not None:
 59            self.test_cumulative_masks_dir: str = os.path.join(
 60                self.save_dir, test_cumulative_masks_dir_name
 61            )
 62            r"""The directory to save the test cumulative mask figures."""
 63            os.makedirs(self.test_cumulative_masks_dir, exist_ok=True)
 64        if test_masks_grouped_dir_name is not None:
 65            self.test_masks_grouped_dir: str = os.path.join(
 66                self.save_dir, test_masks_grouped_dir_name
 67            )
 68            r"""The directory to save the test mask grouped figures."""
 69            os.makedirs(self.test_masks_grouped_dir, exist_ok=True)
 70
 71        if training_masks_dir_name is not None:
 72            self.training_masks_dir: str = os.path.join(
 73                self.save_dir, training_masks_dir_name
 74            )
 75            r"""The directory to save the training mask figures."""
 76            os.makedirs(self.training_masks_dir, exist_ok=True)
 77
 78        # other settings
 79        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
 80        r"""The frequency of plotting training masks in terms of number of batches."""
 81
 82        # task ID control
 83        self.task_id: int
 84        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 85
 86    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
 87        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
 88
 89        **Raises:**
 90        -**TypeError**: when the `pl_module` is not `HAT`.
 91        """
 92
 93        # get the current task_id from the `CLAlgorithm` object
 94        self.task_id = pl_module.task_id
 95
 96        # sanity check
 97        if not isinstance(pl_module, HAT):
 98            raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")
 99
100    @rank_zero_only
101    def on_train_batch_end(
102        self,
103        trainer: Trainer,
104        pl_module: HAT,
105        outputs: dict[str, Any],
106        batch: Any,
107        batch_idx: int,
108    ) -> None:
109        r"""Plot training mask after training batch.
110
111        **Args:**
112        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
113        - **batch** (`Any`): the training data batch.
114        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
115        """
116
117        # get the mask over the model after training the batch
118        mask = outputs["mask"]
119
120        # plot the mask
121        if hasattr(self, "training_masks_dir"):
122            if batch_idx % self.plot_training_mask_every_n_steps == 0:
123                self.plot_hat_mask(
124                    mask=mask,
125                    plot_dir=self.training_masks_dir,
126                    task_id=self.task_id,
127                    step=batch_idx,
128                )
129
130    @rank_zero_only
131    def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
132        r"""Plot test mask and cumulative mask figures."""
133
134        # test mask
135        if hasattr(self, "test_masks_dir"):
136            mask = pl_module.masks[self.task_id]
137            self.plot_hat_mask(
138                mask=mask, plot_dir=self.test_masks_dir, task_id=self.task_id
139            )
140
141        # cumulative mask
142        if hasattr(self, "test_cumulative_masks_dir"):
143            cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
144            self.plot_hat_mask(
145                mask=cumulative_mask,
146                plot_dir=self.test_cumulative_masks_dir,
147                task_id=self.task_id,
148            )
149
150        # summative mask
151        if hasattr(self, "test_masks_grouped_dir"):
152            masks = pl_module.backbone.masks
153            self.plot_hat_mask_grouped(
154                masks=masks,
155                plot_dir=self.test_masks_grouped_dir,
156                task_id=self.task_id,
157            )
158
159    def plot_hat_mask(
160        self,
161        mask: dict[str, Tensor],
162        plot_dir: str,
163        task_id: int,
164        step: int | None = None,
165    ) -> None:
166        """Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm. This includes the mask and cumulative mask.
167
168        **Args:**
169        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
170        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
171        - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
172        - **step** (`int`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
173        """
174
175        for layer_name, m in mask.items():
176            layer_name = layer_name.replace(
177                "/", "."
178            )  # the layer name contains '/', which is not allowed in the file name. We replace it back with '.'.
179
180            m = m.view(
181                1, -1
182            )  # reshape the 1D mask to 2D so can be plotted by image show
183
184            fig = plt.figure()
185            plt.imshow(
186                m.detach().cpu(), aspect="auto", cmap="Greys"
187            )  # can only convert to tensors in CPU to numpy arrays
188            plt.yticks()  # hide yticks
189            plt.colorbar()
190            if step:
191                plot_name = f"{layer_name}_task{task_id}_step{step}.png"
192            else:
193                plot_name = f"{layer_name}_task{task_id}.png"
194            plot_path = os.path.join(plot_dir, plot_name)
195            fig.savefig(plot_path)
196            plt.close(fig)
197
198    def plot_hat_mask_grouped(
199        self,
200        masks: dict[int, dict[str, Tensor]],
201        plot_dir: str,
202        task_id: int,
203    ) -> None:
204        """Plot masks in a grouped way in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm.
205
206        **Args:**
207        - **masks** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) masks. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ).
208        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
209        - **task_id** (`int`): the task ID till which the masks to be plotted. This is to form the plot name.
210        """
class HATMasks(clarena.metrics.base.MetricCallback):
 24class HATMasks(MetricCallback):
 25    r"""Provides all actions that are related to masks of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:
 26
 27    - Visualizing mask and cumulative mask figures during training and testing as figures.
 28
 29    The callback is able to produce the following outputs:
 30
 31    - Figures of both training and test, masks and cumulative masks.
 32
 33    """
 34
 35    def __init__(
 36        self,
 37        save_dir: str,
 38        test_masks_dir_name: str | None = None,
 39        test_cumulative_masks_dir_name: str | None = None,
 40        test_masks_grouped_dir_name: str | None = None,
 41        training_masks_dir_name: str | None = None,
 42        plot_training_mask_every_n_steps: int | None = None,
 43    ) -> None:
 44        r"""
 45        **Args:**
 46        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
 47        - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
 48        - **test_cumulative_masks_dir_name** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
 49        - **test_masks_grouped_dir_name** (`str` | `None`): the directory to save the test mask grouped figures. If `None`, no file will be saved.
 50        - **training_masks_dir_name** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
 51        - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
 52        """
 53        super().__init__(save_dir=save_dir)
 54
 55        # paths
 56        if test_masks_dir_name is not None:
 57            self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
 58            r"""The directory to save the test mask figures."""
 59        if test_cumulative_masks_dir_name is not None:
 60            self.test_cumulative_masks_dir: str = os.path.join(
 61                self.save_dir, test_cumulative_masks_dir_name
 62            )
 63            r"""The directory to save the test cumulative mask figures."""
 64            os.makedirs(self.test_cumulative_masks_dir, exist_ok=True)
 65        if test_masks_grouped_dir_name is not None:
 66            self.test_masks_grouped_dir: str = os.path.join(
 67                self.save_dir, test_masks_grouped_dir_name
 68            )
 69            r"""The directory to save the test mask grouped figures."""
 70            os.makedirs(self.test_masks_grouped_dir, exist_ok=True)
 71
 72        if training_masks_dir_name is not None:
 73            self.training_masks_dir: str = os.path.join(
 74                self.save_dir, training_masks_dir_name
 75            )
 76            r"""The directory to save the training mask figures."""
 77            os.makedirs(self.training_masks_dir, exist_ok=True)
 78
 79        # other settings
 80        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
 81        r"""The frequency of plotting training masks in terms of number of batches."""
 82
 83        # task ID control
 84        self.task_id: int
 85        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 86
 87    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
 88        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
 89
 90        **Raises:**
 91        -**TypeError**: when the `pl_module` is not `HAT`.
 92        """
 93
 94        # get the current task_id from the `CLAlgorithm` object
 95        self.task_id = pl_module.task_id
 96
 97        # sanity check
 98        if not isinstance(pl_module, HAT):
 99            raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")
100
101    @rank_zero_only
102    def on_train_batch_end(
103        self,
104        trainer: Trainer,
105        pl_module: HAT,
106        outputs: dict[str, Any],
107        batch: Any,
108        batch_idx: int,
109    ) -> None:
110        r"""Plot training mask after training batch.
111
112        **Args:**
113        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
114        - **batch** (`Any`): the training data batch.
115        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
116        """
117
118        # get the mask over the model after training the batch
119        mask = outputs["mask"]
120
121        # plot the mask
122        if hasattr(self, "training_masks_dir"):
123            if batch_idx % self.plot_training_mask_every_n_steps == 0:
124                self.plot_hat_mask(
125                    mask=mask,
126                    plot_dir=self.training_masks_dir,
127                    task_id=self.task_id,
128                    step=batch_idx,
129                )
130
131    @rank_zero_only
132    def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
133        r"""Plot test mask and cumulative mask figures."""
134
135        # test mask
136        if hasattr(self, "test_masks_dir"):
137            mask = pl_module.masks[self.task_id]
138            self.plot_hat_mask(
139                mask=mask, plot_dir=self.test_masks_dir, task_id=self.task_id
140            )
141
142        # cumulative mask
143        if hasattr(self, "test_cumulative_masks_dir"):
144            cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
145            self.plot_hat_mask(
146                mask=cumulative_mask,
147                plot_dir=self.test_cumulative_masks_dir,
148                task_id=self.task_id,
149            )
150
151        # summative mask
152        if hasattr(self, "test_masks_grouped_dir"):
153            masks = pl_module.backbone.masks
154            self.plot_hat_mask_grouped(
155                masks=masks,
156                plot_dir=self.test_masks_grouped_dir,
157                task_id=self.task_id,
158            )
159
160    def plot_hat_mask(
161        self,
162        mask: dict[str, Tensor],
163        plot_dir: str,
164        task_id: int,
165        step: int | None = None,
166    ) -> None:
167        """Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm. This includes the mask and cumulative mask.
168
169        **Args:**
170        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
171        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
172        - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
173        - **step** (`int`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
174        """
175
176        for layer_name, m in mask.items():
177            layer_name = layer_name.replace(
178                "/", "."
179            )  # the layer name contains '/', which is not allowed in the file name. We replace it back with '.'.
180
181            m = m.view(
182                1, -1
183            )  # reshape the 1D mask to 2D so can be plotted by image show
184
185            fig = plt.figure()
186            plt.imshow(
187                m.detach().cpu(), aspect="auto", cmap="Greys"
188            )  # can only convert to tensors in CPU to numpy arrays
189            plt.yticks()  # hide yticks
190            plt.colorbar()
191            if step:
192                plot_name = f"{layer_name}_task{task_id}_step{step}.png"
193            else:
194                plot_name = f"{layer_name}_task{task_id}.png"
195            plot_path = os.path.join(plot_dir, plot_name)
196            fig.savefig(plot_path)
197            plt.close(fig)
198
199    def plot_hat_mask_grouped(
200        self,
201        masks: dict[int, dict[str, Tensor]],
202        plot_dir: str,
203        task_id: int,
204    ) -> None:
205        """Plot masks in a grouped way in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm.
206
207        **Args:**
208        - **masks** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) masks. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ).
209        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
210        - **task_id** (`int`): the task ID till which the masks to be plotted. This is to form the plot name.
211        """

Provides all actions that are related to masks of HAT (Hard Attention to the Task) algorithm and its extensions, which include:

  • Visualizing mask and cumulative mask figures during training and testing as figures.

The callback is able to produce the following outputs:

  • Figures of both training and test, masks and cumulative masks.
HATMasks( save_dir: str, test_masks_dir_name: str | None = None, test_cumulative_masks_dir_name: str | None = None, test_masks_grouped_dir_name: str | None = None, training_masks_dir_name: str | None = None, plot_training_mask_every_n_steps: int | None = None)
35    def __init__(
36        self,
37        save_dir: str,
38        test_masks_dir_name: str | None = None,
39        test_cumulative_masks_dir_name: str | None = None,
40        test_masks_grouped_dir_name: str | None = None,
41        training_masks_dir_name: str | None = None,
42        plot_training_mask_every_n_steps: int | None = None,
43    ) -> None:
44        r"""
45        **Args:**
46        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
47        - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
48        - **test_cumulative_masks_dir_name** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
49        - **test_masks_grouped_dir_name** (`str` | `None`): the directory to save the test mask grouped figures. If `None`, no file will be saved.
50        - **training_masks_dir_name** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
51        - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
52        """
53        super().__init__(save_dir=save_dir)
54
55        # paths
56        if test_masks_dir_name is not None:
57            self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
58            r"""The directory to save the test mask figures."""
59        if test_cumulative_masks_dir_name is not None:
60            self.test_cumulative_masks_dir: str = os.path.join(
61                self.save_dir, test_cumulative_masks_dir_name
62            )
63            r"""The directory to save the test cumulative mask figures."""
64            os.makedirs(self.test_cumulative_masks_dir, exist_ok=True)
65        if test_masks_grouped_dir_name is not None:
66            self.test_masks_grouped_dir: str = os.path.join(
67                self.save_dir, test_masks_grouped_dir_name
68            )
69            r"""The directory to save the test mask grouped figures."""
70            os.makedirs(self.test_masks_grouped_dir, exist_ok=True)
71
72        if training_masks_dir_name is not None:
73            self.training_masks_dir: str = os.path.join(
74                self.save_dir, training_masks_dir_name
75            )
76            r"""The directory to save the training mask figures."""
77            os.makedirs(self.training_masks_dir, exist_ok=True)
78
79        # other settings
80        self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps
81        r"""The frequency of plotting training masks in terms of number of batches."""
82
83        # task ID control
84        self.task_id: int
85        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

Args:

  • save_dir (str): the directory to save the mask figures. Better inside the output folder.
  • test_masks_dir_name (str | None): the relative path to save_dir to save the test mask figures. If None, no file will be saved.
  • test_cumulative_masks_dir_name (str | None): the directory to save the test cumulative mask figures. If None, no file will be saved.
  • test_masks_grouped_dir_name (str | None): the directory to save the test mask grouped figures. If None, no file will be saved.
  • training_masks_dir_name (str | None): the directory to save the training mask figures. If None, no file will be saved.
  • plot_training_mask_every_n_steps (int | None): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when training_masks_dir_name is not None.
plot_training_mask_every_n_steps: int

The frequency of plotting training masks in terms of number of batches.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to cl_dataset.num_tasks.

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT) -> None:
87    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
88        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
89
90        **Raises:**
91        -**TypeError**: when the `pl_module` is not `HAT`.
92        """
93
94        # get the current task_id from the `CLAlgorithm` object
95        self.task_id = pl_module.task_id
96
97        # sanity check
98        if not isinstance(pl_module, HAT):
99            raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")

Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the pl_module to be HAT.

Raises: - TypeError: when the pl_module is not HAT.

@rank_zero_only
def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
101    @rank_zero_only
102    def on_train_batch_end(
103        self,
104        trainer: Trainer,
105        pl_module: HAT,
106        outputs: dict[str, Any],
107        batch: Any,
108        batch_idx: int,
109    ) -> None:
110        r"""Plot training mask after training batch.
111
112        **Args:**
113        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
114        - **batch** (`Any`): the training data batch.
115        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
116        """
117
118        # get the mask over the model after training the batch
119        mask = outputs["mask"]
120
121        # plot the mask
122        if hasattr(self, "training_masks_dir"):
123            if batch_idx % self.plot_training_mask_every_n_steps == 0:
124                self.plot_hat_mask(
125                    mask=mask,
126                    plot_dir=self.training_masks_dir,
127                    task_id=self.task_id,
128                    step=batch_idx,
129                )

Plot training mask after training batch.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, which is the returns of the training_step() method in the HAT.
  • batch (Any): the training data batch.
  • batch_idx (int): the index of the current batch. This is for the file name of mask figures.
@rank_zero_only
def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT) -> None:
131    @rank_zero_only
132    def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
133        r"""Plot test mask and cumulative mask figures."""
134
135        # test mask
136        if hasattr(self, "test_masks_dir"):
137            mask = pl_module.masks[self.task_id]
138            self.plot_hat_mask(
139                mask=mask, plot_dir=self.test_masks_dir, task_id=self.task_id
140            )
141
142        # cumulative mask
143        if hasattr(self, "test_cumulative_masks_dir"):
144            cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
145            self.plot_hat_mask(
146                mask=cumulative_mask,
147                plot_dir=self.test_cumulative_masks_dir,
148                task_id=self.task_id,
149            )
150
151        # summative mask
152        if hasattr(self, "test_masks_grouped_dir"):
153            masks = pl_module.backbone.masks
154            self.plot_hat_mask_grouped(
155                masks=masks,
156                plot_dir=self.test_masks_grouped_dir,
157                task_id=self.task_id,
158            )

Plot test mask and cumulative mask figures.

def plot_hat_mask( self, mask: dict[str, torch.Tensor], plot_dir: str, task_id: int, step: int | None = None) -> None:
160    def plot_hat_mask(
161        self,
162        mask: dict[str, Tensor],
163        plot_dir: str,
164        task_id: int,
165        step: int | None = None,
166    ) -> None:
167        """Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm. This includes the mask and cumulative mask.
168
169        **Args:**
170        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
171        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
172        - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
173        - **step** (`int`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
174        """
175
176        for layer_name, m in mask.items():
177            layer_name = layer_name.replace(
178                "/", "."
179            )  # the layer name contains '/', which is not allowed in the file name. We replace it back with '.'.
180
181            m = m.view(
182                1, -1
183            )  # reshape the 1D mask to 2D so can be plotted by image show
184
185            fig = plt.figure()
186            plt.imshow(
187                m.detach().cpu(), aspect="auto", cmap="Greys"
188            )  # can only convert to tensors in CPU to numpy arrays
189            plt.yticks()  # hide yticks
190            plt.colorbar()
191            if step:
192                plot_name = f"{layer_name}_task{task_id}_step{step}.png"
193            else:
194                plot_name = f"{layer_name}_task{task_id}.png"
195            plot_path = os.path.join(plot_dir, plot_name)
196            fig.savefig(plot_path)
197            plt.close(fig)

Plot mask in HAT (Hard Attention to the Task) algorithm. This includes the mask and cumulative mask.

Args:

  • mask (dict[str, Tensor]): the hard attention (whose values are 0 or 1) mask. Keys (str) are layer name and values (Tensor) are the mask tensors. The mask tensor has size (number of units, ).
  • plot_dir (str): the directory to save plot. Better same as the output directory of the experiment.
  • task_id (int): the task ID of the mask to be plotted. This is to form the plot name.
  • step (int): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep None for not showing the step in the plot name.
def plot_hat_mask_grouped( self, masks: dict[int, dict[str, torch.Tensor]], plot_dir: str, task_id: int) -> None:
199    def plot_hat_mask_grouped(
200        self,
201        masks: dict[int, dict[str, Tensor]],
202        plot_dir: str,
203        task_id: int,
204    ) -> None:
205        """Plot masks in a grouped way in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm.
206
207        **Args:**
208        - **masks** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) masks. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ).
209        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment.
210        - **task_id** (`int`): the task ID till which the masks to be plotted. This is to form the plot name.
211        """

Plot masks in a grouped way in HAT (Hard Attention to the Task) algorithm.

Args:

  • masks (dict[int, dict[str, Tensor]]): the hard attention (whose values are 0 or 1) masks. Keys are task IDs and values are the corresponding mask. Each mask is a dict where keys are layer names and values are the binary mask tensor for the layer. The mask tensor has size (number of units, ).
  • plot_dir (str): the directory to save plot. Better same as the output directory of the experiment.
  • task_id (int): the task ID till which the masks to be plotted. This is to form the plot name.