clarena.pipelines.stl_eval
The submodule in pipelines for single-task learning evaluation.
r"""The submodule in `pipelines` for single-task learning evaluation."""

__all__ = ["STLEvaluation"]

import logging

import hydra
import lightning as L
import torch
from lightning import Callback, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from clarena.pipelines import STLExperiment
from clarena.stl_datasets import STLDataset

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class STLEvaluation:
    r"""The base class for single-task learning evaluation."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""Initialise the single-task learning evaluation pipeline.

        **Args:**
        - **cfg** (`DictConfig`): the config dict for the single-task learning evaluation.

        **Raises:**
        - **KeyError**: when a required config field is missing (see `sanity_check`).
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        # call through `self` so subclass overrides of `sanity_check` are honoured
        self.sanity_check()

        # required config fields
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""
        self.model_path: str = cfg.model_path
        r"""The file path of the model to evaluate."""

        # components; instantiated later in `run()`
        self.stl_dataset: STLDataset
        r"""STL dataset object."""
        self.lightning_loggers: list[Logger]
        r"""Lightning logger objects."""
        self.callbacks: list[Callback]
        r"""Callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        **Raises:**
        - **KeyError**: when any required config field is missing or `None`.
        """
        # check required config fields
        required_config_fields = [
            "pipeline",
            "model_path",
            "global_seed",
            "stl_dataset",
            "trainer",
            "metrics",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            # compare against `None` rather than truthiness so that legitimate
            # falsy values (e.g. `global_seed: 0`) are not rejected
            if self.cfg.get(field) is None:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

    def instantiate_stl_dataset(
        self,
        stl_dataset_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the STL dataset object from `stl_dataset_cfg` into `self.stl_dataset`."""
        pylogger.debug(
            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
            stl_dataset_cfg.get("_target_"),
        )
        self.stl_dataset = hydra.utils.instantiate(
            stl_dataset_cfg,
        )
        pylogger.debug(
            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
            stl_dataset_cfg.get("_target_"),
        )

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`.

        `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of
        callbacks other than the metric callbacks. The instantiated `self.callbacks`
        contains both, metric callbacks first.
        """
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")

        # instantiate metric callbacks
        metric_callbacks = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # add metric callbacks to the list of callbacks
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`.

        NOTE(review): `self.lightning_loggers` is declared on the class but never
        wired into the trainer here — confirm whether loggers should be attached.
        """
        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire evaluation.

        **Args:**
        - **global_seed** (`int`): the seed to apply globally.
        """
        # seed with the argument; previously this seeded `self.global_seed`,
        # silently ignoring the parameter
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the single-task learning evaluation."""
        self.set_global_seed(self.global_seed)

        # load the model from file
        # NOTE(review): `torch.load` unpickles arbitrary objects — only load
        # model files from trusted sources
        model = torch.load(self.model_path)

        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after callbacks

        # setup task for dataset
        self.stl_dataset.setup_task()

        # evaluation skipping training and validation
        self.trainer.test(
            model=model,
            datamodule=self.stl_dataset,
        )
class STLEvaluation:
    r"""The base class for single-task learning evaluation."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""Initialise the single-task learning evaluation pipeline.

        **Args:**
        - **cfg** (`DictConfig`): the config dict for the single-task learning evaluation.

        **Raises:**
        - **KeyError**: when a required config field is missing (see `sanity_check`).
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        # call through `self` so subclass overrides of `sanity_check` are honoured
        self.sanity_check()

        # required config fields
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""
        self.model_path: str = cfg.model_path
        r"""The file path of the model to evaluate."""

        # components; instantiated later in `run()`
        self.stl_dataset: STLDataset
        r"""STL dataset object."""
        self.lightning_loggers: list[Logger]
        r"""Lightning logger objects."""
        self.callbacks: list[Callback]
        r"""Callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        **Raises:**
        - **KeyError**: when any required config field is missing or `None`.
        """
        # check required config fields
        required_config_fields = [
            "pipeline",
            "model_path",
            "global_seed",
            "stl_dataset",
            "trainer",
            "metrics",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            # compare against `None` rather than truthiness so that legitimate
            # falsy values (e.g. `global_seed: 0`) are not rejected
            if self.cfg.get(field) is None:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

    def instantiate_stl_dataset(
        self,
        stl_dataset_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the STL dataset object from `stl_dataset_cfg` into `self.stl_dataset`."""
        pylogger.debug(
            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
            stl_dataset_cfg.get("_target_"),
        )
        self.stl_dataset = hydra.utils.instantiate(
            stl_dataset_cfg,
        )
        pylogger.debug(
            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
            stl_dataset_cfg.get("_target_"),
        )

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`.

        `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of
        callbacks other than the metric callbacks. The instantiated `self.callbacks`
        contains both, metric callbacks first.
        """
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")

        # instantiate metric callbacks
        metric_callbacks = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # add metric callbacks to the list of callbacks
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`.

        NOTE(review): `self.lightning_loggers` is declared on the class but never
        wired into the trainer here — confirm whether loggers should be attached.
        """
        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire evaluation.

        **Args:**
        - **global_seed** (`int`): the seed to apply globally.
        """
        # seed with the argument; previously this seeded `self.global_seed`,
        # silently ignoring the parameter
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the single-task learning evaluation."""
        self.set_global_seed(self.global_seed)

        # load the model from file
        # NOTE(review): `torch.load` unpickles arbitrary objects — only load
        # model files from trusted sources
        model = torch.load(self.model_path)

        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after callbacks

        # setup task for dataset
        self.stl_dataset.setup_task()

        # evaluation skipping training and validation
        self.trainer.test(
            model=model,
            datamodule=self.stl_dataset,
        )
The base class for single-task learning evaluation.
STLEvaluation(cfg: omegaconf.dictconfig.DictConfig)
def __init__(self, cfg: DictConfig) -> None:
    r"""Initialise the single-task learning evaluation pipeline.

    **Args:**
    - **cfg** (`DictConfig`): the config dict for the single-task learning evaluation.

    **Raises:**
    - **KeyError**: when a required config field is missing (see `sanity_check`).
    """
    self.cfg: DictConfig = cfg
    r"""The complete config dict."""

    # call through `self` so subclass overrides of `sanity_check` are honoured
    # (previously called `STLEvaluation.sanity_check(self)`, bypassing overrides)
    self.sanity_check()

    # required config fields
    self.global_seed: int = cfg.global_seed
    r"""The global seed for the entire experiment."""
    self.output_dir: str = cfg.output_dir
    r"""The folder for storing the experiment results."""
    self.model_path: str = cfg.model_path
    r"""The file path of the model to evaluate."""

    # components; instantiated later in `run()`
    self.stl_dataset: STLDataset
    r"""STL dataset object."""
    self.lightning_loggers: list[Logger]
    r"""Lightning logger objects."""
    self.callbacks: list[Callback]
    r"""Callback objects."""
    self.trainer: Trainer
    r"""Trainer object."""
Args:
- cfg (
DictConfig): the config dict for the single-task learning evaluation.
def
sanity_check(self) -> None:
def sanity_check(self) -> None:
    r"""Sanity check for config.

    **Raises:**
    - **KeyError**: when any required config field is missing or `None`.
    """
    # check required config fields
    required_config_fields = [
        "pipeline",
        "model_path",
        "global_seed",
        "stl_dataset",
        "trainer",
        "metrics",
        "callbacks",
        "output_dir",
        # "hydra" is excluded as it doesn't appear
        "misc",
    ]
    for field in required_config_fields:
        # compare against `None` rather than truthiness so that legitimate
        # falsy values (e.g. `global_seed: 0`) are not rejected
        if self.cfg.get(field) is None:
            raise KeyError(
                f"Field `{field}` is required in the experiment index config."
            )
Sanity check for config.
def
instantiate_stl_dataset(self, stl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
def instantiate_stl_dataset(
    self,
    stl_dataset_cfg: DictConfig,
) -> None:
    r"""Build the STL dataset object from `stl_dataset_cfg` and store it on `self.stl_dataset`."""
    target = stl_dataset_cfg.get("_target_")
    pylogger.debug(
        "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
        target,
    )
    self.stl_dataset = hydra.utils.instantiate(stl_dataset_cfg)
    pylogger.debug(
        "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
        target,
    )
Instantiate the STL dataset object from stl_dataset_cfg.
def
instantiate_callbacks( self, metrics_cfg: omegaconf.dictconfig.DictConfig, callbacks_cfg: omegaconf.dictconfig.DictConfig) -> None:
def instantiate_callbacks(
    self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
) -> None:
    r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`.

    `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of
    callbacks other than the metric callbacks. The resulting `self.callbacks`
    contains both, metric callbacks first.
    """
    pylogger.debug("Instantiating callbacks (lightning.Callback)...")

    instantiated: list[Callback] = []
    # metric callbacks first, then the remaining callbacks
    for config_group in (metrics_cfg, callbacks_cfg):
        for callback_cfg in config_group:
            instantiated.append(hydra.utils.instantiate(callback_cfg))

    self.callbacks = instantiated
    pylogger.debug("Callbacks (lightning.Callback) instantiated!")
Instantiate the list of callback objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.
def
instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, callbacks: list[lightning.pytorch.callbacks.callback.Callback]) -> None:
def instantiate_trainer(
    self,
    trainer_cfg: DictConfig,
    callbacks: list[Callback],
) -> None:
    r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`.

    NOTE(review): the previous docstring also mentioned `lightning_loggers`,
    but no loggers are passed to the trainer here — confirm whether
    `self.lightning_loggers` should be attached.
    """
    pylogger.debug("Instantiating trainer (lightning.Trainer)...")
    self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
    pylogger.debug("Trainer (lightning.Trainer) instantiated!")
Instantiate the trainer object from trainer_cfg and callbacks.
def
set_global_seed(self, global_seed: int) -> None:
def set_global_seed(self, global_seed: int) -> None:
    r"""Set the `global_seed` for the entire evaluation.

    **Args:**
    - **global_seed** (`int`): the seed to apply globally.
    """
    # seed with the argument; previously this seeded `self.global_seed`,
    # silently ignoring the parameter (the log message already reported
    # the parameter value)
    L.seed_everything(global_seed, workers=True)
    pylogger.debug("Global seed is set as %d.", global_seed)
Set the global_seed for the entire evaluation.
def
run(self) -> None:
def run(self) -> None:
    r"""The main method to run the single-task learning evaluation."""
    self.set_global_seed(self.global_seed)

    # load the model from file
    # NOTE(review): `torch.load` unpickles arbitrary objects — only load
    # model files from trusted sources
    model = torch.load(self.model_path)

    self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
    self.instantiate_callbacks(
        metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
    )
    self.instantiate_trainer(
        trainer_cfg=self.cfg.trainer,
        callbacks=self.callbacks,
    )  # trainer should be instantiated after callbacks

    # setup task for dataset
    self.stl_dataset.setup_task()

    # evaluation skipping training and validation
    self.trainer.test(
        model=model,
        datamodule=self.stl_dataset,
    )
The main method to run the single-task learning evaluation.