r"""The submodule in `pipelines` for multi-task learning evaluation."""

__all__ = ["MTLEvaluation"]

import logging

import hydra
import lightning as L
import torch
from lightning import Callback, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, ListConfig

from clarena.mtl_datasets import MTLDataset

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class MTLEvaluation:
    r"""The base class for multi-task learning evaluation."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the config dict for the multi-task learning evaluation.
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        self.sanity_check()

        # required config fields
        self.eval_tasks: list[int] = (
            list(cfg.eval_tasks)
            if isinstance(cfg.eval_tasks, (ListConfig, list))
            else list(range(1, cfg.eval_tasks + 1))
        )
        r"""The list of task IDs to evaluate."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""
        self.model_path: str = cfg.model_path
        r"""The file path of the model to evaluate."""

        # components
        self.mtl_dataset: MTLDataset
        r"""MTL dataset object."""
        self.lightning_loggers: list[Logger]
        r"""Lightning logger objects."""
        self.callbacks: list[Callback]
        r"""Callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        **Raises:**
        - **KeyError**: when a required config field is missing.
        - **ValueError**: when `eval_tasks` is empty or out of the valid task ID range.
        - **TypeError**: when `eval_tasks` is neither a list of integers nor an integer.
        """

        # check required config fields. Compare against None (not truthiness) so
        # falsy-but-valid values such as `global_seed: 0` are not rejected.
        required_config_fields = [
            "pipeline",
            "model_path",
            "eval_tasks",
            "global_seed",
            "trainer",
            "metrics",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            if self.cfg.get(field) is None:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # get dataset number of tasks
        if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL":
            cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset")
            if cl_dataset_cfg is None:
                # guard before attribute access; a missing sub-config would
                # otherwise surface as an opaque AttributeError
                raise KeyError(
                    "`cl_dataset` is required in mtl_dataset config when using `MTLDatasetFromCL`."
                )
            if cl_dataset_cfg.get("num_tasks"):
                num_tasks = cl_dataset_cfg.get("num_tasks")
            elif cl_dataset_cfg.get("class_split"):
                num_tasks = len(cl_dataset_cfg.class_split)
            elif cl_dataset_cfg.get("datasets"):
                num_tasks = len(cl_dataset_cfg.datasets)
            else:
                raise KeyError(
                    "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
                )
        else:
            if self.cfg.mtl_dataset.get("num_tasks"):
                num_tasks = self.cfg.mtl_dataset.num_tasks
            else:
                raise KeyError(
                    "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config."
                )

        # check eval_tasks. OmegaConf wraps YAML lists as `ListConfig`, which is
        # NOT a `list` subclass, so both types must be accepted here (the
        # `__init__` branch already distinguishes on `ListConfig`).
        eval_tasks = self.cfg.eval_tasks
        if isinstance(eval_tasks, (list, ListConfig)):
            if len(eval_tasks) < 1:
                raise ValueError("`eval_tasks` must contain at least one task.")
            if any(t < 1 or t > num_tasks for t in eval_tasks):
                raise ValueError(
                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
                )
        elif isinstance(eval_tasks, int):
            # lower bound is 1: an integer of 0 would expand to an empty task
            # list, contradicting the "at least one task" rule above
            if eval_tasks < 1 or eval_tasks > num_tasks:
                raise ValueError(
                    f"`eval_tasks` as integer must be between 1 and {num_tasks}."
                )
        else:
            raise TypeError(
                "`eval_tasks` must be either a list of integers or an integer."
            )

    def instantiate_mtl_dataset(
        self,
        mtl_dataset_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`."""
        pylogger.debug(
            "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) ...",
            mtl_dataset_cfg.get("_target_"),
        )
        self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg)
        pylogger.debug(
            "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated! ",
            mtl_dataset_cfg.get("_target_"),
        )

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")

        # instantiate metric callbacks
        metric_callbacks = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # add metric callbacks to the list of callbacks
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""

        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire evaluation."""
        # seed from the argument (the original read `self.global_seed`, silently
        # ignoring the parameter) so the method behaves as its signature promises
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the multi-task learning evaluation."""

        self.set_global_seed(self.global_seed)

        # load the model from file
        # NOTE(review): `torch.load` unpickles arbitrary objects — only load
        # model files from trusted sources
        model = torch.load(self.model_path)

        self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset)
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after callbacks

        # setup tasks for dataset
        self.mtl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)

        # evaluation skipping training and validation
        self.trainer.test(
            model=model,
            datamodule=self.mtl_dataset,
        )