clarena.metrics.hat_masks
The submodule in `metrics` for `HATMasks`.
r"""
The submodule in `metrics` for `HATMasks`.
"""

__all__ = ["HATMasks"]

import logging
import os
from typing import Any

from lightning import Trainer
from lightning.pytorch.utilities import rank_zero_only
from matplotlib import pyplot as plt
from torch import Tensor

from clarena.cl_algorithms import HAT
from clarena.metrics import MetricCallback

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class HATMasks(MetricCallback):
    r"""Provides all actions that are related to masks of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:

    - Visualizing mask and cumulative mask figures during training and testing as figures.

    The callback is able to produce the following outputs:

    - Figures of both training and test, masks and cumulative masks.

    """

    def __init__(
        self,
        save_dir: str,
        test_masks_dir_name: str | None = None,
        test_cumulative_masks_dir_name: str | None = None,
        training_masks_dir_name: str | None = None,
        plot_training_mask_every_n_steps: int | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
        - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
        - **test_cumulative_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test cumulative mask figures. If `None`, no file will be saved.
        - **training_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the training mask figures. If `None`, no file will be saved.
        - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
        """
        super().__init__(save_dir=save_dir)

        # paths: each attribute is only created when its directory name is given,
        # because the hooks below use `hasattr` to decide whether to plot
        if test_masks_dir_name is not None:
            self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
            r"""The directory to save the test mask figures."""
        if test_cumulative_masks_dir_name is not None:
            self.test_cumulative_masks_dir: str = os.path.join(
                self.save_dir, test_cumulative_masks_dir_name
            )
            r"""The directory to save the test cumulative mask figures."""
        if training_masks_dir_name is not None:
            self.training_masks_dir: str = os.path.join(
                self.save_dir, training_masks_dir_name
            )
            r"""The directory to save the training mask figures."""

        # other settings
        self.plot_training_mask_every_n_steps: int | None = (
            plot_training_mask_every_n_steps
        )
        r"""The frequency of plotting training masks in terms of number of batches. `None` when not configured."""

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.

        **Raises:**
        - **TypeError**: when the `pl_module` is not `HAT`.
        """

        # sanity check first, so that a wrong module type is always reported as the documented `TypeError`
        if not isinstance(pl_module, HAT):
            raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")

        # get the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: HAT,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Plot training mask after training batch.

        Nothing is plotted when the training mask directory was not configured, or when `plot_training_mask_every_n_steps` is `None` or 0.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`. Its `"mask"` entry is the one plotted.
        - **batch** (`Any`): the training data batch.
        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
        """

        # training mask figures are not requested
        if not hasattr(self, "training_masks_dir"):
            return

        every_n_steps = self.plot_training_mask_every_n_steps
        if not every_n_steps:
            # a missing or zero frequency would make the modulo below fail in the middle of training
            if batch_idx == 0:
                pylogger.warning(
                    "Training mask figures are skipped because `plot_training_mask_every_n_steps` is %r.",
                    every_n_steps,
                )
            return

        if batch_idx % every_n_steps != 0:
            return

        # plot the mask over the model after training the batch
        self.plot_hat_mask(
            mask=outputs["mask"],
            plot_dir=self.training_masks_dir,
            task_id=self.task_id,
            step=batch_idx,
        )

    @rank_zero_only
    def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
        r"""Plot test mask and cumulative mask figures."""

        # test mask
        if hasattr(self, "test_masks_dir"):
            mask = pl_module.masks[self.task_id]
            self.plot_hat_mask(
                mask=mask, plot_dir=self.test_masks_dir, task_id=self.task_id
            )

        # cumulative mask
        if hasattr(self, "test_cumulative_masks_dir"):
            cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
            self.plot_hat_mask(
                mask=cumulative_mask,
                plot_dir=self.test_cumulative_masks_dir,
                task_id=self.task_id,
            )

    def plot_hat_mask(
        self,
        mask: dict[str, Tensor],
        plot_dir: str,
        task_id: int,
        step: int | None = None,
    ) -> None:
        r"""Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm. This includes the mask and cumulative mask.

        One PNG file is saved per layer, named `{layer_name}_task{task_id}.png`, or `{layer_name}_task{task_id}_step{step}.png` when `step` is given.

        **Args:**
        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment. Created if it does not exist.
        - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
        - **step** (`int` | `None`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
        """

        # make sure the target directory exists before any figure is saved
        os.makedirs(plot_dir, exist_ok=True)

        # compare with `None` rather than truthiness, so that step 0 still shows up in the plot name
        step_suffix = f"_step{step}" if step is not None else ""

        for layer_name, m in mask.items():
            # '/' in the layer name would act as a path separator in the file name, so it is replaced with '.'
            layer_name = layer_name.replace("/", ".")

            # move to CPU (required for conversion to a numpy array) and
            # reshape the mask to a single row so it can be plotted by image show
            m = m.detach().cpu().reshape(1, -1)

            fig, ax = plt.subplots()
            try:
                image = ax.imshow(m, aspect="auto", cmap="Greys")
                ax.set_yticks([])  # hide yticks; the single row has no meaningful y axis
                fig.colorbar(image, ax=ax)

                plot_name = f"{layer_name}_task{task_id}{step_suffix}.png"
                fig.savefig(os.path.join(plot_dir, plot_name))
            finally:
                # release the figure even when plotting or saving fails
                plt.close(fig)
class HATMasks(MetricCallback):
    r"""Provides all actions that are related to masks of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:

    - Visualizing mask and cumulative mask figures during training and testing as figures.

    The callback is able to produce the following outputs:

    - Figures of both training and test, masks and cumulative masks.

    """

    def __init__(
        self,
        save_dir: str,
        test_masks_dir_name: str | None = None,
        test_cumulative_masks_dir_name: str | None = None,
        training_masks_dir_name: str | None = None,
        plot_training_mask_every_n_steps: int | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
        - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
        - **test_cumulative_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test cumulative mask figures. If `None`, no file will be saved.
        - **training_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the training mask figures. If `None`, no file will be saved.
        - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
        """
        super().__init__(save_dir=save_dir)

        # paths: each attribute is only created when its directory name is given,
        # because the hooks below use `hasattr` to decide whether to plot
        if test_masks_dir_name is not None:
            self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
            r"""The directory to save the test mask figures."""
        if test_cumulative_masks_dir_name is not None:
            self.test_cumulative_masks_dir: str = os.path.join(
                self.save_dir, test_cumulative_masks_dir_name
            )
            r"""The directory to save the test cumulative mask figures."""
        if training_masks_dir_name is not None:
            self.training_masks_dir: str = os.path.join(
                self.save_dir, training_masks_dir_name
            )
            r"""The directory to save the training mask figures."""

        # other settings
        self.plot_training_mask_every_n_steps: int | None = (
            plot_training_mask_every_n_steps
        )
        r"""The frequency of plotting training masks in terms of number of batches. `None` when not configured."""

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
        r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.

        **Raises:**
        - **TypeError**: when the `pl_module` is not `HAT`.
        """

        # sanity check first, so that a wrong module type is always reported as the documented `TypeError`
        if not isinstance(pl_module, HAT):
            raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")

        # get the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: HAT,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Plot training mask after training batch.

        Nothing is plotted when the training mask directory was not configured, or when `plot_training_mask_every_n_steps` is `None` or 0.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`. Its `"mask"` entry is the one plotted.
        - **batch** (`Any`): the training data batch.
        - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
        """

        # training mask figures are not requested
        if not hasattr(self, "training_masks_dir"):
            return

        every_n_steps = self.plot_training_mask_every_n_steps
        if not every_n_steps:
            # a missing or zero frequency would make the modulo below fail in the middle of training
            if batch_idx == 0:
                pylogger.warning(
                    "Training mask figures are skipped because `plot_training_mask_every_n_steps` is %r.",
                    every_n_steps,
                )
            return

        if batch_idx % every_n_steps != 0:
            return

        # plot the mask over the model after training the batch
        self.plot_hat_mask(
            mask=outputs["mask"],
            plot_dir=self.training_masks_dir,
            task_id=self.task_id,
            step=batch_idx,
        )

    @rank_zero_only
    def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
        r"""Plot test mask and cumulative mask figures."""

        # test mask
        if hasattr(self, "test_masks_dir"):
            mask = pl_module.masks[self.task_id]
            self.plot_hat_mask(
                mask=mask, plot_dir=self.test_masks_dir, task_id=self.task_id
            )

        # cumulative mask
        if hasattr(self, "test_cumulative_masks_dir"):
            cumulative_mask = pl_module.cumulative_mask_for_previous_tasks
            self.plot_hat_mask(
                mask=cumulative_mask,
                plot_dir=self.test_cumulative_masks_dir,
                task_id=self.task_id,
            )

    def plot_hat_mask(
        self,
        mask: dict[str, Tensor],
        plot_dir: str,
        task_id: int,
        step: int | None = None,
    ) -> None:
        r"""Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm. This includes the mask and cumulative mask.

        One PNG file is saved per layer, named `{layer_name}_task{task_id}.png`, or `{layer_name}_task{task_id}_step{step}.png` when `step` is given.

        **Args:**
        - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
        - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment. Created if it does not exist.
        - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
        - **step** (`int` | `None`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
        """

        # make sure the target directory exists before any figure is saved
        os.makedirs(plot_dir, exist_ok=True)

        # compare with `None` rather than truthiness, so that step 0 still shows up in the plot name
        step_suffix = f"_step{step}" if step is not None else ""

        for layer_name, m in mask.items():
            # '/' in the layer name would act as a path separator in the file name, so it is replaced with '.'
            layer_name = layer_name.replace("/", ".")

            # move to CPU (required for conversion to a numpy array) and
            # reshape the mask to a single row so it can be plotted by image show
            m = m.detach().cpu().reshape(1, -1)

            fig, ax = plt.subplots()
            try:
                image = ax.imshow(m, aspect="auto", cmap="Greys")
                ax.set_yticks([])  # hide yticks; the single row has no meaningful y axis
                fig.colorbar(image, ax=ax)

                plot_name = f"{layer_name}_task{task_id}{step_suffix}.png"
                fig.savefig(os.path.join(plot_dir, plot_name))
            finally:
                # release the figure even when plotting or saving fails
                plt.close(fig)
Provides all actions that are related to masks of HAT (Hard Attention to the Task) algorithm and its extensions, which include:
- Visualizing mask and cumulative mask figures during training and testing as figures.
The callback is able to produce the following outputs:
- Figures of both training and test, masks and cumulative masks.
def __init__(
    self,
    save_dir: str,
    test_masks_dir_name: str | None = None,
    test_cumulative_masks_dir_name: str | None = None,
    training_masks_dir_name: str | None = None,
    plot_training_mask_every_n_steps: int | None = None,
) -> None:
    r"""
    **Args:**
    - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
    - **test_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
    - **test_cumulative_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the test cumulative mask figures. If `None`, no file will be saved.
    - **training_masks_dir_name** (`str` | `None`): the relative path to `save_dir` to save the training mask figures. If `None`, no file will be saved.
    - **plot_training_mask_every_n_steps** (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
    """
    super().__init__(save_dir=save_dir)

    # paths: each attribute is only created when its directory name is given,
    # because the hooks of this callback use `hasattr` to decide whether to plot
    if test_masks_dir_name is not None:
        self.test_masks_dir: str = os.path.join(self.save_dir, test_masks_dir_name)
        r"""The directory to save the test mask figures."""
    if test_cumulative_masks_dir_name is not None:
        self.test_cumulative_masks_dir: str = os.path.join(
            self.save_dir, test_cumulative_masks_dir_name
        )
        r"""The directory to save the test cumulative mask figures."""
    if training_masks_dir_name is not None:
        self.training_masks_dir: str = os.path.join(
            self.save_dir, training_masks_dir_name
        )
        r"""The directory to save the training mask figures."""

    # other settings (the value may be `None`, so the annotation must allow it)
    self.plot_training_mask_every_n_steps: int | None = (
        plot_training_mask_every_n_steps
    )
    r"""The frequency of plotting training masks in terms of number of batches. `None` when not configured."""

    # task ID control (declared here, assigned in `on_fit_start()`)
    self.task_id: int
    r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
Args:
- save_dir (`str`): the directory to save the mask figures. Better inside the output folder.
- test_masks_dir_name (`str` | `None`): the relative path to `save_dir` to save the test mask figures. If `None`, no file will be saved.
- test_cumulative_masks_dir_name (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
- training_masks_dir_name (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
- plot_training_mask_every_n_steps (`int` | `None`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_dir_name` is not `None`.
The frequency of plotting training masks in terms of number of batches.
Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`.
def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
    r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.

    **Args:**
    - **trainer** (`Trainer`): the Lightning trainer. Not used here.
    - **pl_module** (`HAT`): the module being fitted, whose `task_id` is copied to `self.task_id`.

    **Raises:**
    - **TypeError**: when the `pl_module` is not `HAT`.
    """

    # sanity check first, so that a wrong module type is always reported as the documented `TypeError`
    # (rather than an `AttributeError` from reading `task_id` on an unrelated object)
    if not isinstance(pl_module, HAT):
        raise TypeError("The `CLAlgorithm` should be `HAT` to apply `HATMasks`!")

    # get the current task_id from the `CLAlgorithm` object
    self.task_id = pl_module.task_id
Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.
Raises:
- TypeError: when the `pl_module` is not `HAT`.
@rank_zero_only
def on_train_batch_end(
    self,
    trainer: Trainer,
    pl_module: HAT,
    outputs: dict[str, Any],
    batch: Any,
    batch_idx: int,
) -> None:
    r"""Plot training mask after training batch.

    Nothing is plotted when the training mask directory was not configured, or when `plot_training_mask_every_n_steps` is `None` or 0.

    **Args:**
    - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`. Its `"mask"` entry is the one plotted.
    - **batch** (`Any`): the training data batch.
    - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures.
    """

    # training mask figures are not requested
    if not hasattr(self, "training_masks_dir"):
        return

    every_n_steps = self.plot_training_mask_every_n_steps
    if not every_n_steps:
        # a missing or zero frequency would make the modulo below fail in the middle of training
        if batch_idx == 0:
            pylogger.warning(
                "Training mask figures are skipped because `plot_training_mask_every_n_steps` is %r.",
                every_n_steps,
            )
        return

    if batch_idx % every_n_steps != 0:
        return

    # plot the mask over the model after training the batch
    self.plot_hat_mask(
        mask=outputs["mask"],
        plot_dir=self.training_masks_dir,
        task_id=self.task_id,
        step=batch_idx,
    )
Plot training mask after training batch.
Args:
- outputs (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`.
- batch (`Any`): the training data batch.
- batch_idx (`int`): the index of the current batch. This is for the file name of mask figures.
@rank_zero_only
def on_test_start(self, trainer: Trainer, pl_module: HAT) -> None:
    r"""Save the test mask figures and the cumulative mask figures when testing starts.

    Each kind of figure is only produced when its directory attribute exists on this callback.

    **Args:**
    - **trainer** (`Trainer`): the Lightning trainer. Not used here.
    - **pl_module** (`HAT`): the module under test, which provides `masks` and `cumulative_mask_for_previous_tasks`.
    """

    # mask stored in `pl_module.masks` under the current `self.task_id`
    if hasattr(self, "test_masks_dir"):
        self.plot_hat_mask(
            mask=pl_module.masks[self.task_id],
            plot_dir=self.test_masks_dir,
            task_id=self.task_id,
        )

    # cumulative mask held by the module
    if hasattr(self, "test_cumulative_masks_dir"):
        self.plot_hat_mask(
            mask=pl_module.cumulative_mask_for_previous_tasks,
            plot_dir=self.test_cumulative_masks_dir,
            task_id=self.task_id,
        )
Plot test mask and cumulative mask figures.
def plot_hat_mask(
    self,
    mask: dict[str, Tensor],
    plot_dir: str,
    task_id: int,
    step: int | None = None,
) -> None:
    r"""Plot mask in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm. This includes the mask and cumulative mask.

    One PNG file is saved per layer, named `{layer_name}_task{task_id}.png`, or `{layer_name}_task{task_id}_step{step}.png` when `step` is given.

    **Args:**
    - **mask** (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
    - **plot_dir** (`str`): the directory to save plot. Better same as the output directory of the experiment. Created if it does not exist.
    - **task_id** (`int`): the task ID of the mask to be plotted. This is to form the plot name.
    - **step** (`int` | `None`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.
    """

    # make sure the target directory exists before any figure is saved
    os.makedirs(plot_dir, exist_ok=True)

    # compare with `None` rather than truthiness, so that step 0 still shows up in the plot name
    step_suffix = f"_step{step}" if step is not None else ""

    for layer_name, m in mask.items():
        # '/' in the layer name would act as a path separator in the file name, so it is replaced with '.'
        layer_name = layer_name.replace("/", ".")

        # move to CPU (required for conversion to a numpy array) and
        # reshape the mask to a single row so it can be plotted by image show
        m = m.detach().cpu().reshape(1, -1)

        fig, ax = plt.subplots()
        try:
            image = ax.imshow(m, aspect="auto", cmap="Greys")
            ax.set_yticks([])  # hide yticks; the single row has no meaningful y axis
            fig.colorbar(image, ax=ax)

            plot_name = f"{layer_name}_task{task_id}{step_suffix}.png"
            fig.savefig(os.path.join(plot_dir, plot_name))
        finally:
            # release the figure even when plotting or saving fails
            plt.close(fig)
Plot mask in HAT (Hard Attention to the Task) algorithm. This includes the mask and cumulative mask.
Args:
- mask (`dict[str, Tensor]`): the hard attention (whose values are 0 or 1) mask. Keys (`str`) are layer name and values (`Tensor`) are the mask tensors. The mask tensor has size (number of units, ).
- plot_dir (`str`): the directory to save plot. Better same as the output directory of the experiment.
- task_id (`int`): the task ID of the mask to be plotted. This is to form the plot name.
- step (`int`): the training step (batch index) of the mask to be plotted. Apply to the training mask only. This is to form the plot name. Keep `None` for not showing the step in the plot name.