clarena.callbacks.hat_metrics

The submodule in `callbacks` for `HATMetricsCallback`.
1r""" 2The submodule in `callbacks` for `HATMetricsCallback`. 3""" 4 5__all__ = ["HATMetricsCallback"] 6 7import logging 8import os 9from typing import Any 10 11from lightning import Callback, Trainer 12 13from clarena.cl_algorithms import HAT, CLAlgorithm 14from clarena.utils import plot 15 16# always get logger for built-in logging in each module 17pylogger = logging.getLogger(__name__) 18 19 20class HATMetricsCallback(Callback): 21 r"""Provides all actions that are related to metrics used for [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm, which include: 22 23 - Visualising mask and cumulative mask figures during training and testing as figures. 24 - Logging network capacity during training. See the "Evaluation Metrics" section in chapter 4.1 in [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) for more details about network capacity. 25 26 Lightning provides `self.log()` to log metrics in `LightningModule` where our `CLAlgorithm` based. You can put `self.log()` here if you don't want to mess up the `CLAlgorithm` with a huge amount of logging codes. 27 28 The callback is able to produce the following outputs: 29 30 - **Mask Figures**: both training and test, masks and cumulative masks. 31 32 33 """ 34 35 def __init__( 36 self, 37 test_masks_plot_dir: str | None, 38 test_cumulative_masks_plot_dir: str | None, 39 training_masks_plot_dir: str | None, 40 plot_training_mask_every_n_steps: int | None, 41 ) -> None: 42 r"""Initialise the `HATMetricsCallback`. 43 44 **Args:** 45 - **test_masks_plot_dir** (`str` | `None`): the directory to save the test mask figures. If `None`, no file will be saved. 46 - **test_cumulative_masks_plot_dir** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved. 47 - **training_masks_plot_dir** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved. 48 - **plot_training_mask_every_n_steps** (`int`): the frequency of plotting training mask figures in terms of number of batches during training. Only applies when `training_masks_plot_dir` is not `None`. 49 """ 50 Callback.__init__(self) 51 52 # paths 53 if not os.path.exists(test_masks_plot_dir): 54 os.makedirs(test_masks_plot_dir, exist_ok=True) 55 self.test_masks_plot_dir: str = test_masks_plot_dir 56 r"""Store the directory to save the test mask figures.""" 57 if not os.path.exists(test_cumulative_masks_plot_dir): 58 os.makedirs(test_cumulative_masks_plot_dir, exist_ok=True) 59 self.test_cumulative_masks_plot_dir: str = test_cumulative_masks_plot_dir 60 r"""Store the directory to save the test cumulative mask figures.""" 61 if not os.path.exists(training_masks_plot_dir): 62 os.makedirs(training_masks_plot_dir, exist_ok=True) 63 self.training_masks_plot_dir: str = training_masks_plot_dir 64 r"""Store the directory to save the training mask figures.""" 65 66 # other settings 67 self.plot_training_mask_every_n_steps: int = plot_training_mask_every_n_steps 68 r"""Store the frequency of plotting training masks in terms of number of batches.""" 69 70 self.task_id: int 71 r"""Task ID counter indicating which task is being processed. Self updated during the task loop.""" 72 73 def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None: 74 r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT` or `AdaHAT`. 
75 76 **Raises:** 77 -**TypeError**: when the `pl_module` is not `HAT` or `AdaHAT`. 78 """ 79 80 # get the current task_id from the `CLAlgorithm` object 81 self.task_id = pl_module.task_id 82 83 # sanity check 84 if not isinstance(pl_module, HAT): 85 raise TypeError( 86 "The `CLAlgorithm` should be `HAT` or `AdaHAT` to apply `HATMetricsCallback`!" 87 ) 88 89 def on_train_batch_end( 90 self, 91 trainer: Trainer, 92 pl_module: CLAlgorithm, 93 outputs: dict[str, Any], 94 batch: Any, 95 batch_idx: int, 96 ) -> None: 97 r"""Plot training mask and log network capacity after training batch. 98 99 **Args:** 100 - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `CLAlgorithm`. 101 - **batch** (`Any`): the training data batch. 102 - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures. 103 """ 104 105 # get the mask over the model after training the batch 106 mask = outputs["mask"] 107 # get the current network capacity 108 capacity = outputs["capacity"] 109 110 # plot the mask 111 if batch_idx % self.plot_training_mask_every_n_steps == 0: 112 plot.plot_hat_mask( 113 mask=mask, 114 plot_dir=self.training_masks_plot_dir, 115 task_id=self.task_id, 116 step=batch_idx, 117 ) 118 119 # log the network capacity to Lightning loggers 120 pl_module.log( 121 f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True 122 ) 123 124 def on_test_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None: 125 r"""Plot test mask and cumulative mask figures.""" 126 127 # test mask 128 mask = pl_module.masks[f"{self.task_id}"] 129 plot.plot_hat_mask( 130 mask=mask, plot_dir=self.test_masks_plot_dir, task_id=self.task_id 131 ) 132 133 # cumulative mask 134 cumulative_mask = pl_module.cumulative_mask_for_previous_tasks 135 plot.plot_hat_mask( 136 mask=cumulative_mask, 137 plot_dir=self.test_cumulative_masks_plot_dir, 138 task_id=self.task_id, 139 )
class HATMetricsCallback(Callback)
Provides all actions that are related to metrics used for the [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm, which include:

- Visualising masks and cumulative masks as figures during training and testing.
- Logging network capacity during training. See the "Evaluation Metrics" section in chapter 4.1 of the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) for more details about network capacity.

Lightning provides `self.log()` to log metrics in the `LightningModule`, on which our `CLAlgorithm` is based. You can put `self.log()` here instead if you don't want to clutter the `CLAlgorithm` with a large amount of logging code.

The callback is able to produce the following outputs:

- **Mask figures**: both training and test, masks and cumulative masks.
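For context, here is a minimal sketch of how this callback could be attached to a Lightning `Trainer` alongside a `HAT` module. The output paths and trainer settings are hypothetical, and the construction of the `HAT` algorithm itself (`model` below) and of the datamodule is omitted:

```python
from lightning import Trainer

from clarena.callbacks.hat_metrics import HATMetricsCallback

# Hypothetical output directories; the callback creates them if they don't exist.
hat_metrics = HATMetricsCallback(
    test_masks_plot_dir="outputs/masks/test",
    test_cumulative_masks_plot_dir="outputs/masks/test_cumulative",
    training_masks_plot_dir="outputs/masks/train",
    plot_training_mask_every_n_steps=100,
)

# `model` is assumed to be a configured `HAT` (or `AdaHAT`) instance from
# `clarena.cl_algorithms`; `datamodule` is assumed to be a matching CL datamodule.
trainer = Trainer(callbacks=[hat_metrics], max_epochs=1)
# trainer.fit(model, datamodule=datamodule)
# trainer.test(model, datamodule=datamodule)
```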
def __init__(self, test_masks_plot_dir: str | None, test_cumulative_masks_plot_dir: str | None, training_masks_plot_dir: str | None, plot_training_mask_every_n_steps: int | None) -> None
Initialise the `HATMetricsCallback`.

**Args:**
- **test_masks_plot_dir** (`str` | `None`): the directory to save the test mask figures. If `None`, no file will be saved.
- **test_cumulative_masks_plot_dir** (`str` | `None`): the directory to save the test cumulative mask figures. If `None`, no file will be saved.
- **training_masks_plot_dir** (`str` | `None`): the directory to save the training mask figures. If `None`, no file will be saved.
- **plot_training_mask_every_n_steps** (`int` | `None`): how often (in training batches) to plot training mask figures. Only applies when `training_masks_plot_dir` is not `None`.
**plot_training_mask_every_n_steps**: store the frequency of plotting training masks in terms of number of batches.

**task_id**: task ID counter indicating which task is being processed; updated automatically during the task loop.
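The last two constructor parameters only take effect together: nothing is plotted during training when `training_masks_plot_dir` is `None`, and otherwise a figure is written every `plot_training_mask_every_n_steps` batches (a simple modulo test on the batch index). A minimal sketch of that documented contract, using a hypothetical helper that is not part of the library:

```python
def should_plot_training_mask(
    batch_idx: int,
    plot_dir: str | None,
    every_n_steps: int | None,
) -> bool:
    """Return True when a training-mask figure should be written for this batch.

    Mirrors the documented semantics: no figures when `plot_dir` is `None`,
    otherwise one figure every `every_n_steps` training batches.
    """
    if plot_dir is None or every_n_steps is None:
        return False
    return batch_idx % every_n_steps == 0


# With plot_training_mask_every_n_steps=100, batches 0, 100, 200, ... are plotted.
assert should_plot_training_mask(0, "outputs/masks/train", 100)
assert not should_plot_training_mask(50, "outputs/masks/train", 100)
assert not should_plot_training_mask(0, None, 100)
```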
def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None
Get the current task ID at the beginning of a task's fitting (training and validation) and sanity-check that `pl_module` is `HAT` or `AdaHAT`.

**Raises:**
- **TypeError**: when the `pl_module` is not `HAT` or `AdaHAT`.
def on_train_batch_end(self, trainer: Trainer, pl_module: CLAlgorithm, outputs: dict[str, Any], batch: Any, batch_idx: int) -> None
Plot the training mask and log network capacity after each training batch.

**Args:**
- **outputs** (`dict[str, Any]`): the outputs of the training step, i.e. the return value of the `training_step()` method in the `CLAlgorithm`.
- **batch** (`Any`): the training data batch.
- **batch_idx** (`int`): the index of the current batch. This is used in the file name of the mask figures.
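This hook relies on the `training_step()` of the `HAT` algorithm returning a dictionary that contains at least the `mask` and `capacity` keys read above. A hedged sketch of that contract follows; the `loss` entry and the per-layer mask layout are assumptions for illustration and are not defined in this module:

```python
from typing import Any

import torch


def training_step_output_sketch(loss: torch.Tensor) -> dict[str, Any]:
    """Illustrative shape of the dictionary `on_train_batch_end` expects."""
    return {
        "loss": loss,  # assumed: the standard Lightning key used for backpropagation
        "mask": {"fc1": torch.ones(64)},  # assumed layout: layer name -> unit gate values
        "capacity": 0.42,  # scalar network capacity (see AdaHAT paper, chapter 4.1)
    }
```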
def on_test_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None
Plot test mask and cumulative mask figures.