clarena.metrics.hat_adjustment_rate

The submodule in `metrics` for `HATAdjustmentRate`.
1r""" 2The submodule in `metrics` for `HATAdjustmentRate`. 3""" 4 5__all__ = ["HATAdjustmentRate"] 6 7import logging 8import os 9from typing import Any 10 11from lightning import Trainer 12from lightning.pytorch.utilities import rank_zero_only 13from matplotlib import pyplot as plt 14from matplotlib.colors import LogNorm 15from torch import Tensor 16 17from clarena.cl_algorithms import HAT 18from clarena.metrics import MetricCallback 19 20# always get logger for built-in logging in each module 21pylogger = logging.getLogger(__name__) 22 23 24class HATAdjustmentRate(MetricCallback): 25 r"""Provides all actions that are related to adjustment rate of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include: 26 27 - Visualizing adjustment rate during training as figures. 28 29 The callback is able to produce the following outputs: 30 31 - Figures of training adjustment rate. 32 33 """ 34 35 def __init__( 36 self, 37 save_dir: str, 38 plot_adjustment_rate_every_n_steps: int | None = None, 39 ) -> None: 40 r""" 41 **Args:** 42 - **save_dir** (`str` | `None`): The directory to save the adjustment rate figures. Better inside the output folder. 43 - **plot_adjustment_rate_every_n_steps** (`int` | `None`): the frequency of plotting adjustment rate figures in terms of number of batches during training. 44 """ 45 super().__init__(save_dir=save_dir) 46 47 # other settings 48 self.plot_adjustment_rate_every_n_steps: int = ( 49 plot_adjustment_rate_every_n_steps 50 ) 51 r"""The frequency of plotting adjustment rate in terms of number of batches.""" 52 53 # task ID control 54 self.task_id: int 55 r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`.""" 56 57 def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None: 58 r"""Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`. 59 60 **Raises:** 61 -**TypeError**: when the `pl_module` is not `HAT`. 62 """ 63 64 # get the current task_id from the `CLAlgorithm` object 65 self.task_id = pl_module.task_id 66 67 # sanity check 68 if not isinstance(pl_module, HAT): 69 raise TypeError( 70 "The `CLAlgorithm` should be `HAT` to apply `HATAdjustmentRate`!" 71 ) 72 73 @rank_zero_only 74 def on_train_batch_end( 75 self, 76 trainer: Trainer, 77 pl_module: HAT, 78 outputs: dict[str, Any], 79 batch: Any, 80 batch_idx: int, 81 ) -> None: 82 r"""Plot adjustment rate after training batch. 83 84 **Args:** 85 - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the returns of the `training_step()` method in the `HAT`. 86 - **batch** (`Any`): the training data batch. 87 - **batch_idx** (`int`): the index of the current batch. This is for the file name of mask figures. 
88 """ 89 90 # get the adjustment rate 91 adjustment_rate_weight = outputs["adjustment_rate_weight"] 92 adjustment_rate_bias = outputs["adjustment_rate_bias"] 93 94 # plot the adjustment rate 95 if self.save_dir is not None: 96 if self.task_id > 1: 97 if batch_idx % self.plot_adjustment_rate_every_n_steps == 0: 98 self.plot_hat_adjustment_rate( 99 adjustment_rate=adjustment_rate_weight, 100 weight_or_bias="weight", 101 step=batch_idx, 102 ) 103 self.plot_hat_adjustment_rate( 104 adjustment_rate=adjustment_rate_bias, 105 weight_or_bias="bias", 106 step=batch_idx, 107 ) 108 109 def plot_hat_adjustment_rate( 110 self, 111 adjustment_rate: dict[str, Tensor], 112 weight_or_bias: str, 113 step: int | None = None, 114 ) -> None: 115 """Plot adjustment rate in [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a)) algorithm. This includes the adjustment rate weight and adjustment rate bias (if applicable). 116 117 **Args:** 118 - **adjustment_rate** (`dict[str, Tensor]`): the adjustment rate. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors. If it's adjustment rate weight, it has size same as weights. If it's adjustment rate bias, it has size same as biases. 119 - **weight_or_bias** (`str`): the type of adjustment rate. It can be either 'weight' or 'bias'. This is to form the plot name. 120 - **step** (`int`): the training step (batch index) of the adjustment rate to be plotted. This is to form the plot name. Keep `None` for not showing the step in the plot name. 121 """ 122 123 for layer_name, a in adjustment_rate.items(): 124 layer_name = layer_name.replace( 125 "/", "." 126 ) # the layer name contains '/', which is not allowed in the file name. We replace it back with '.'. 127 128 if weight_or_bias == "bias": 129 a = a.view( 130 1, -1 131 ) # reshape the 1D mask to 2D so can be plotted by image show 132 133 fig = plt.figure() 134 plt.imshow( 135 a.detach().cpu(), 136 aspect="auto", 137 cmap="Wistia", 138 norm=LogNorm(vmin=1e-7, vmax=1e-5), 139 ) # can only convert to tensors in CPU to numpy arrays 140 plt.yticks() # hide yticks 141 plt.colorbar() 142 if step: 143 plot_name = ( 144 f"{layer_name}_{weight_or_bias}_task{self.task_id}_step{step}.png" 145 ) 146 else: 147 plot_name = f"{layer_name}_{weight_or_bias}_task{self.task_id}.png" 148 plot_path = os.path.join(self.save_dir, plot_name) 149 fig.savefig(plot_path) 150 plt.close(fig)
class HATAdjustmentRate(MetricCallback)

Provides all actions related to the adjustment rate of the [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:

- Visualizing the adjustment rate during training as figures.

The callback produces the following outputs:

- Figures of the training adjustment rate.
__init__(save_dir: str, plot_adjustment_rate_every_n_steps: int | None = None)

**Args:**
- **save_dir** (`str`): the directory to save the adjustment rate figures. Preferably inside the output folder.
- **plot_adjustment_rate_every_n_steps** (`int` | `None`): the frequency of plotting adjustment rate figures, in number of training batches. Keep `None` to disable plotting.

**Attributes:**
- **plot_adjustment_rate_every_n_steps**: the frequency of plotting the adjustment rate, in number of batches.
- **task_id**: task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`.
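A minimal usage sketch, assuming a trained-from-scratch setup where a `HAT` module and data module are already configured elsewhere; `hat_module`, `datamodule`, and the output path below are hypothetical names, not part of this module:

```python
from lightning import Trainer

from clarena.metrics import HATAdjustmentRate

# hypothetical output path; the directory is assumed to exist or to be
# created by the base `MetricCallback`
callback = HATAdjustmentRate(
    save_dir="outputs/hat_adjustment_rate",
    plot_adjustment_rate_every_n_steps=100,  # plot every 100 training batches
)

trainer = Trainer(callbacks=[callback], max_epochs=1)
# trainer.fit(hat_module, datamodule)  # `hat_module` must be a `HAT` instance
```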
on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None

Get the current task ID at the beginning of a task's fitting (training and validation). Sanity-check that `pl_module` is a `HAT` instance.

**Raises:**
- **TypeError**: when `pl_module` is not `HAT`.
on_train_batch_end(self, trainer: Trainer, pl_module: HAT, outputs: dict[str, Any], batch: Any, batch_idx: int) -> None

Plot the adjustment rate after a training batch.

**Args:**
- **outputs** (`dict[str, Any]`): the outputs of the training step, i.e. the return value of the `training_step()` method of `HAT`.
- **batch** (`Any`): the training data batch.
- **batch_idx** (`int`): the index of the current batch. Used in the file names of the figures.
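For reference, a minimal sketch of the `outputs` dict this hook reads. The two key names come from the source above; the layer name and tensor shapes are illustrative assumptions, mirroring a hypothetical `Linear(784 -> 128)` layer:

```python
import torch

outputs = {
    # one entry per layer; tensor has the same shape as the layer's weight
    "adjustment_rate_weight": {"fc1": torch.rand(128, 784)},
    # one entry per layer; tensor has the same shape as the layer's bias
    "adjustment_rate_bias": {"fc1": torch.rand(128)},
}
```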
plot_hat_adjustment_rate(self, adjustment_rate: dict[str, Tensor], weight_or_bias: str, step: int | None = None) -> None

Plot the adjustment rate of the [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm. This includes the adjustment rate of weights and, if applicable, of biases.

**Args:**
- **adjustment_rate** (`dict[str, Tensor]`): the adjustment rate. Keys (`str`) are layer names and values (`Tensor`) are the adjustment rate tensors, with the same size as the weights (for weight adjustment rates) or the biases (for bias adjustment rates).
- **weight_or_bias** (`str`): the type of adjustment rate, either 'weight' or 'bias'. Used to form the plot name.
- **step** (`int` | `None`): the training step (batch index) of the adjustment rate to be plotted. Used to form the plot name. Keep `None` to omit the step from the plot name.
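A hedged example of calling the method directly on synthetic data; the callback instance, layer name, and values below are hypothetical, and `task_id` is set manually here only because it is normally assigned in `on_fit_start`:

```python
import torch

from clarena.metrics import HATAdjustmentRate

callback = HATAdjustmentRate(save_dir="outputs/figures")  # directory assumed to exist
callback.task_id = 2  # normally set by `on_fit_start`

# 1e-6 sits inside the LogNorm range [1e-7, 1e-5] used by the plot
callback.plot_hat_adjustment_rate(
    adjustment_rate={"backbone/fc1": torch.full((64, 32), 1e-6)},
    weight_or_bias="weight",
    step=100,
)
# expected output file (the '/' in the layer name is replaced with '.'):
# outputs/figures/backbone.fc1_weight_task2_step100.png
```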