clarena.metrics.hat_network_capacity
The submodule in `metrics` for `HATNetworkCapacity`.
r"""The submodule in `metrics` for `HATNetworkCapacity`."""

__all__ = ["HATNetworkCapacity"]

import logging
from typing import Any

from lightning import Trainer
from lightning.pytorch.utilities import rank_zero_only

from clarena.cl_algorithms import HAT
from clarena.metrics import MetricCallback

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class HATNetworkCapacity(MetricCallback):
    r"""Provides all actions that are related to network capacity of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:

    - Logging network capacity during training. See the "Evaluation Metrics" section in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) for more details about network capacity.
    """

    def __init__(
        self,
        save_dir: str,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): the directory to save the metric results (forwarded to `MetricCallback`). Better inside the output folder.
        """
        super().__init__(save_dir=save_dir)

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
        r"""Get the current task ID at the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.

        **Raises:**
        - **TypeError**: when the `pl_module` is not `HAT`.
        """
        # sanity check FIRST, before reading any HAT-specific attribute;
        # otherwise a wrong `pl_module` raises AttributeError instead of
        # the documented TypeError
        if not isinstance(pl_module, HAT):
            raise TypeError(
                "The `CLAlgorithm` should be `HAT` to apply `HATNetworkCapacity`!"
            )

        # get the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: HAT,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Log the network capacity after each training batch.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, which are the returns of the `training_step()` method in `HAT`.
        - **batch** (`Any`): the training data batch.
        - **batch_idx** (`int`): the index of the current batch.
        """
        # the training step computes and returns the current network capacity
        capacity = outputs["capacity"]

        # log the network capacity to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True
        )
class
HATNetworkCapacity(clarena.metrics.base.MetricCallback):
class HATNetworkCapacity(MetricCallback):
    r"""Provides all actions that are related to network capacity of [HAT (Hard Attention to the Task)](http://proceedings.mlr.press/v80/serra18a) algorithm and its extensions, which include:

    - Logging network capacity during training. See the "Evaluation Metrics" section in Sec. 4.1 in the [AdaHAT paper](https://link.springer.com/chapter/10.1007/978-3-031-70352-2_9) for more details about network capacity.
    """

    def __init__(
        self,
        save_dir: str,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): the directory to save the metric results (forwarded to `MetricCallback`). Better inside the output folder.
        """
        super().__init__(save_dir=save_dir)

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
        r"""Get the current task ID at the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.

        **Raises:**
        - **TypeError**: when the `pl_module` is not `HAT`.
        """
        # sanity check FIRST, before reading any HAT-specific attribute;
        # otherwise a wrong `pl_module` raises AttributeError instead of
        # the documented TypeError
        if not isinstance(pl_module, HAT):
            raise TypeError(
                "The `CLAlgorithm` should be `HAT` to apply `HATNetworkCapacity`!"
            )

        # get the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: HAT,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Log the network capacity after each training batch.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, which are the returns of the `training_step()` method in `HAT`.
        - **batch** (`Any`): the training data batch.
        - **batch_idx** (`int`): the index of the current batch.
        """
        # the training step computes and returns the current network capacity
        capacity = outputs["capacity"]

        # log the network capacity to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True
        )
Provides all actions that are related to the network capacity of the HAT (Hard Attention to the Task) algorithm and its extensions, which include:
- Logging network capacity during training. See the "Evaluation Metrics" section in Sec. 4.1 of the AdaHAT paper for more details about network capacity.
HATNetworkCapacity(save_dir: str)
def __init__(self, save_dir: str) -> None:
    r"""
    **Args:**
    - **save_dir** (`str`): the directory to save the mask figures. Better inside the output folder.
    """
    super().__init__(save_dir=save_dir)

    # task ID control: self updated during the task loop,
    # valid from 1 to `cl_dataset.num_tasks`
    self.task_id: int
Args:
- save_dir (
str
): the directory to save the mask figures. Better inside the output folder.
task_id: int
Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to cl_dataset.num_tasks
.
def
on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT) -> None:
def on_fit_start(self, trainer: Trainer, pl_module: HAT) -> None:
    r"""Get the current task ID at the beginning of a task's fitting (training and validation). Sanity check the `pl_module` to be `HAT`.

    **Raises:**
    - **TypeError**: when the `pl_module` is not `HAT`.
    """
    # sanity check FIRST, before reading any HAT-specific attribute;
    # otherwise a wrong `pl_module` raises AttributeError instead of
    # the documented TypeError
    if not isinstance(pl_module, HAT):
        raise TypeError(
            "The `CLAlgorithm` should be `HAT` to apply `HATNetworkCapacity`!"
        )

    # get the current task_id from the `CLAlgorithm` object
    self.task_id = pl_module.task_id
Get the current task ID in the beginning of a task's fitting (training and validation). Sanity check the pl_module
to be HAT
.
Raises:
- TypeError: when the pl_module is not HAT.
@rank_zero_only
def
on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.hat.HAT, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
@rank_zero_only
def on_train_batch_end(
    self,
    trainer: Trainer,
    pl_module: HAT,
    outputs: dict[str, Any],
    batch: Any,
    batch_idx: int,
) -> None:
    r"""Log the network capacity after each training batch.

    **Args:**
    - **outputs** (`dict[str, Any]`): the outputs of the training step, which are the returns of the `training_step()` method in `HAT`.
    - **batch** (`Any`): the training data batch.
    - **batch_idx** (`int`): the index of the current batch.
    """
    # the training step computes and returns the current network capacity
    capacity = outputs["capacity"]

    # log the network capacity to Lightning loggers
    pl_module.log(
        f"task_{self.task_id}/train/network_capacity", capacity, prog_bar=True
    )
Log the network capacity after each training batch.
Args:
- outputs (
dict[str, Any]
): the outputs of the training step, which are the returns of the `training_step()`
method in `HAT`
. - batch (
Any
): the training data batch. - batch_idx (
int
): the index of the current batch. This is for the file name of mask figures.