clarena.pipelines.cl_main_eval
The submodule in pipelines for continual learning main evaluation.
r"""
The submodule in `pipelines` for continual learning main evaluation.
"""

__all__ = ["CLMainEvaluation"]

import logging

import hydra
import lightning as L
import torch
from lightning import Callback, Trainer
from omegaconf import DictConfig, ListConfig

from clarena.cl_datasets import CLDataset

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CLMainEvaluation:
    r"""The base class for continual learning main evaluation."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the config dict for the continual learning main evaluation.
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        # validate before reading fields; called through the class so a
        # subclass override cannot weaken the base checks
        CLMainEvaluation.sanity_check(self)

        # required config fields
        self.main_model_path: str = cfg.main_model_path
        r"""The file path of the model to evaluate."""
        # FIX: was `cfg.clmlp`, a key that exists nowhere in the config;
        # `sanity_check` requires and validates `cl_paradigm`, so read that.
        self.cl_paradigm: str = cfg.cl_paradigm
        r"""The continual learning paradigm."""
        # an integer N is shorthand for "evaluate tasks 1..N"
        self.eval_tasks: list[int] = (
            cfg.eval_tasks
            if isinstance(cfg.eval_tasks, ListConfig)
            else list(range(1, cfg.eval_tasks + 1))
        )
        r"""The list of task IDs to evaluate."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""

        # components (instantiated later in `run()`)
        self.cl_dataset: CLDataset
        r"""CL dataset object."""
        self.callbacks: list[Callback]
        r"""Callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        **Raises:**
        - **KeyError**: when a required config field is missing.
        - **ValueError**: when `cl_paradigm` or `eval_tasks` is invalid.
        - **TypeError**: when `eval_tasks` is neither a list nor an integer.
        """

        # check required config fields
        required_config_fields = [
            "pipeline",
            "main_model_path",
            "eval_tasks",
            "cl_paradigm",
            "global_seed",
            "cl_dataset",
            "trainer",
            "metrics",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            # FIX: test presence, not truthiness; `self.cfg.get(field)` would
            # wrongly reject valid falsy values such as `global_seed: 0`
            if field not in self.cfg:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # check cl_paradigm
        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
            raise ValueError(
                f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
            )

        # check eval_tasks: derive the total number of tasks from whichever
        # dataset-config field defines it
        if self.cfg.cl_dataset.get("num_tasks"):
            num_tasks = self.cfg.cl_dataset.get("num_tasks")
        elif self.cfg.cl_dataset.get("class_split"):
            num_tasks = len(self.cfg.cl_dataset.class_split)
        elif self.cfg.cl_dataset.get("datasets"):
            num_tasks = len(self.cfg.cl_dataset.datasets)
        else:
            raise KeyError(
                "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
            )

        eval_tasks = self.cfg.eval_tasks
        if isinstance(eval_tasks, ListConfig):
            if len(eval_tasks) < 1:
                raise ValueError("`eval_tasks` must contain at least one task.")
            if any(t < 1 or t > num_tasks for t in eval_tasks):
                raise ValueError(
                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
                )
        elif isinstance(eval_tasks, int):
            # FIX: lower bound was 0, which made `__init__` build an empty
            # evaluation list — inconsistent with the list branch above.
            # An integer N means "evaluate tasks 1..N", so require N >= 1.
            if eval_tasks < 1 or eval_tasks > num_tasks:
                raise ValueError(
                    f"`eval_tasks` as integer must be between 1 and {num_tasks}."
                )
        else:
            raise TypeError(
                "`eval_tasks` must be either a list of integers or an integer."
            )

    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            cl_dataset_cfg.get("_target_"),
        )
        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            cl_dataset_cfg.get("_target_"),
        )

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")

        # instantiate metric callbacks
        metric_callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # metric callbacks come first in the combined list
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""

        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(
            trainer_cfg,
            callbacks=callbacks,
        )
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire evaluation."""
        # FIX: seed with the `global_seed` argument; the original ignored it
        # and read `self.global_seed`, making the parameter misleading (the
        # log line below already reported the argument)
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the continual learning main evaluation."""

        self.set_global_seed(self.global_seed)

        # load the model from file
        # NOTE(review): `torch.load` unpickles arbitrary objects — only
        # evaluate model files from trusted sources
        model = torch.load(self.main_model_path)

        # components
        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after callbacks

        # setup tasks for dataset
        self.cl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)

        # evaluation skipping training and validation
        self.trainer.test(
            model=model,
            datamodule=self.cl_dataset,
        )
class
CLMainEvaluation:
class CLMainEvaluation:
    r"""The base class for continual learning main evaluation."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the config dict for the continual learning main evaluation.
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        # validate before reading fields; called through the class so a
        # subclass override cannot weaken the base checks
        CLMainEvaluation.sanity_check(self)

        # required config fields
        self.main_model_path: str = cfg.main_model_path
        r"""The file path of the model to evaluate."""
        # FIX: was `cfg.clmlp`, a key that exists nowhere in the config;
        # `sanity_check` requires and validates `cl_paradigm`, so read that.
        self.cl_paradigm: str = cfg.cl_paradigm
        r"""The continual learning paradigm."""
        # an integer N is shorthand for "evaluate tasks 1..N"
        self.eval_tasks: list[int] = (
            cfg.eval_tasks
            if isinstance(cfg.eval_tasks, ListConfig)
            else list(range(1, cfg.eval_tasks + 1))
        )
        r"""The list of task IDs to evaluate."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""

        # components (instantiated later in `run()`)
        self.cl_dataset: CLDataset
        r"""CL dataset object."""
        self.callbacks: list[Callback]
        r"""Callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        **Raises:**
        - **KeyError**: when a required config field is missing.
        - **ValueError**: when `cl_paradigm` or `eval_tasks` is invalid.
        - **TypeError**: when `eval_tasks` is neither a list nor an integer.
        """

        # check required config fields
        required_config_fields = [
            "pipeline",
            "main_model_path",
            "eval_tasks",
            "cl_paradigm",
            "global_seed",
            "cl_dataset",
            "trainer",
            "metrics",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            # FIX: test presence, not truthiness; `self.cfg.get(field)` would
            # wrongly reject valid falsy values such as `global_seed: 0`
            if field not in self.cfg:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # check cl_paradigm
        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
            raise ValueError(
                f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
            )

        # check eval_tasks: derive the total number of tasks from whichever
        # dataset-config field defines it
        if self.cfg.cl_dataset.get("num_tasks"):
            num_tasks = self.cfg.cl_dataset.get("num_tasks")
        elif self.cfg.cl_dataset.get("class_split"):
            num_tasks = len(self.cfg.cl_dataset.class_split)
        elif self.cfg.cl_dataset.get("datasets"):
            num_tasks = len(self.cfg.cl_dataset.datasets)
        else:
            raise KeyError(
                "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
            )

        eval_tasks = self.cfg.eval_tasks
        if isinstance(eval_tasks, ListConfig):
            if len(eval_tasks) < 1:
                raise ValueError("`eval_tasks` must contain at least one task.")
            if any(t < 1 or t > num_tasks for t in eval_tasks):
                raise ValueError(
                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
                )
        elif isinstance(eval_tasks, int):
            # FIX: lower bound was 0, which made `__init__` build an empty
            # evaluation list — inconsistent with the list branch above.
            # An integer N means "evaluate tasks 1..N", so require N >= 1.
            if eval_tasks < 1 or eval_tasks > num_tasks:
                raise ValueError(
                    f"`eval_tasks` as integer must be between 1 and {num_tasks}."
                )
        else:
            raise TypeError(
                "`eval_tasks` must be either a list of integers or an integer."
            )

    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            cl_dataset_cfg.get("_target_"),
        )
        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            cl_dataset_cfg.get("_target_"),
        )

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")

        # instantiate metric callbacks
        metric_callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks: list[Callback] = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # metric callbacks come first in the combined list
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""

        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(
            trainer_cfg,
            callbacks=callbacks,
        )
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire evaluation."""
        # FIX: seed with the `global_seed` argument; the original ignored it
        # and read `self.global_seed`, making the parameter misleading (the
        # log line below already reported the argument)
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the continual learning main evaluation."""

        self.set_global_seed(self.global_seed)

        # load the model from file
        # NOTE(review): `torch.load` unpickles arbitrary objects — only
        # evaluate model files from trusted sources
        model = torch.load(self.main_model_path)

        # components
        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after callbacks

        # setup tasks for dataset
        self.cl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)

        # evaluation skipping training and validation
        self.trainer.test(
            model=model,
            datamodule=self.cl_dataset,
        )
The base class for continual learning main evaluation.
CLMainEvaluation(cfg: omegaconf.dictconfig.DictConfig)
def __init__(self, cfg: DictConfig) -> None:
    r"""
    **Args:**
    - **cfg** (`DictConfig`): the config dict for the continual learning main evaluation.
    """
    self.cfg: DictConfig = cfg
    r"""The complete config dict."""

    # validate before reading fields; called through the class so a
    # subclass override cannot weaken the base checks
    CLMainEvaluation.sanity_check(self)

    # required config fields
    self.main_model_path: str = cfg.main_model_path
    r"""The file path of the model to evaluate."""
    # FIX: was `cfg.clmlp`, a key that exists nowhere in the config;
    # `sanity_check` requires and validates `cl_paradigm`, so read that.
    self.cl_paradigm: str = cfg.cl_paradigm
    r"""The continual learning paradigm."""
    # an integer N is shorthand for "evaluate tasks 1..N"
    self.eval_tasks: list[int] = (
        cfg.eval_tasks
        if isinstance(cfg.eval_tasks, ListConfig)
        else list(range(1, cfg.eval_tasks + 1))
    )
    r"""The list of task IDs to evaluate."""
    self.global_seed: int = cfg.global_seed
    r"""The global seed for the entire experiment."""
    self.output_dir: str = cfg.output_dir
    r"""The folder for storing the experiment results."""

    # components (instantiated later in `run()`)
    self.cl_dataset: CLDataset
    r"""CL dataset object."""
    self.callbacks: list[Callback]
    r"""Callback objects."""
    self.trainer: Trainer
    r"""Trainer object."""
Args:
- cfg (
DictConfig): the config dict for the continual learning main evaluation.
def
sanity_check(self) -> None:
def sanity_check(self) -> None:
    r"""Sanity check for config.

    **Raises:**
    - **KeyError**: when a required config field is missing.
    - **ValueError**: when `cl_paradigm` or `eval_tasks` is invalid.
    - **TypeError**: when `eval_tasks` is neither a list nor an integer.
    """

    # check required config fields
    required_config_fields = [
        "pipeline",
        "main_model_path",
        "eval_tasks",
        "cl_paradigm",
        "global_seed",
        "cl_dataset",
        "trainer",
        "metrics",
        "callbacks",
        "output_dir",
        # "hydra" is excluded as it doesn't appear
        "misc",
    ]
    for field in required_config_fields:
        # FIX: test presence, not truthiness; `self.cfg.get(field)` would
        # wrongly reject valid falsy values such as `global_seed: 0`
        if field not in self.cfg:
            raise KeyError(
                f"Field `{field}` is required in the experiment index config."
            )

    # check cl_paradigm
    if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
        raise ValueError(
            f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
        )

    # check eval_tasks: derive the total number of tasks from whichever
    # dataset-config field defines it
    if self.cfg.cl_dataset.get("num_tasks"):
        num_tasks = self.cfg.cl_dataset.get("num_tasks")
    elif self.cfg.cl_dataset.get("class_split"):
        num_tasks = len(self.cfg.cl_dataset.class_split)
    elif self.cfg.cl_dataset.get("datasets"):
        num_tasks = len(self.cfg.cl_dataset.datasets)
    else:
        raise KeyError(
            "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
        )

    eval_tasks = self.cfg.eval_tasks
    if isinstance(eval_tasks, ListConfig):
        if len(eval_tasks) < 1:
            raise ValueError("`eval_tasks` must contain at least one task.")
        if any(t < 1 or t > num_tasks for t in eval_tasks):
            raise ValueError(
                f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
            )
    elif isinstance(eval_tasks, int):
        # FIX: lower bound was 0, which made `__init__` build an empty
        # evaluation list — inconsistent with the list branch above.
        # An integer N means "evaluate tasks 1..N", so require N >= 1.
        if eval_tasks < 1 or eval_tasks > num_tasks:
            raise ValueError(
                f"`eval_tasks` as integer must be between 1 and {num_tasks}."
            )
    else:
        raise TypeError(
            "`eval_tasks` must be either a list of integers or an integer."
        )
Sanity check for config.
def
instantiate_cl_dataset(self, cl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
    r"""Build the CL dataset object declared by `cl_dataset_cfg` and store it on `self.cl_dataset`."""
    # resolve the Hydra target once so both log lines report the same name
    target_name = cl_dataset_cfg.get("_target_")
    pylogger.debug(
        "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
        target_name,
    )
    self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
    pylogger.debug(
        "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
        target_name,
    )
Instantiate the CL dataset object from cl_dataset_cfg.
def
instantiate_callbacks( self, metrics_cfg: omegaconf.dictconfig.DictConfig, callbacks_cfg: omegaconf.dictconfig.DictConfig) -> None:
def instantiate_callbacks(
    self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
) -> None:
    r"""Build `self.callbacks` from the two callback config lists.

    `metrics_cfg` holds the metric callbacks and `callbacks_cfg` the remaining
    (non-metric) callbacks; the combined list keeps metric callbacks first.
    """
    pylogger.debug("Instantiating callbacks (lightning.Callback)...")

    # one pass over both config groups, metric callbacks first
    self.callbacks = [
        hydra.utils.instantiate(callback_cfg)
        for config_group in (metrics_cfg, callbacks_cfg)
        for callback_cfg in config_group
    ]

    pylogger.debug("Callbacks (lightning.Callback) instantiated!")
Instantiate the list of callbacks objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.
def
instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, callbacks: list[lightning.pytorch.callbacks.callback.Callback]) -> None:
def instantiate_trainer(
    self,
    trainer_cfg: DictConfig,
    callbacks: list[Callback],
) -> None:
    r"""Build `self.trainer` from `trainer_cfg`, attaching the given `callbacks`."""
    pylogger.debug("Instantiating trainer (lightning.Trainer)...")
    # the callback list is injected as a keyword override of the Hydra config
    self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
    pylogger.debug("Trainer (lightning.Trainer) instantiated!")
Instantiate the trainer object from trainer_cfg and callbacks.
def
set_global_seed(self, global_seed: int) -> None:
def set_global_seed(self, global_seed: int) -> None:
    r"""Set the `global_seed` for the entire evaluation.

    **Args:**
    - **global_seed** (`int`): the seed applied to all RNGs (and dataloader workers) via Lightning.
    """
    # FIX: seed with the `global_seed` argument; the original ignored it and
    # read `self.global_seed`, making the parameter misleading (the log line
    # below already reported the argument)
    L.seed_everything(global_seed, workers=True)
    pylogger.debug("Global seed is set as %d.", global_seed)
Set the global_seed for the entire evaluation.
def
run(self) -> None:
def run(self) -> None:
    r"""The main method to run the continual learning main evaluation."""
    self.set_global_seed(self.global_seed)

    # load the model from file
    # NOTE(review): `torch.load` unpickles arbitrary objects — only evaluate
    # model files from trusted sources
    model = torch.load(self.main_model_path)

    # assemble components; the trainer must be built after the callbacks,
    # since it receives them at construction time
    self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
    self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
    self.instantiate_callbacks(
        metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
    )
    self.instantiate_trainer(
        trainer_cfg=self.cfg.trainer, callbacks=self.callbacks
    )

    # configure the dataset for the requested evaluation tasks
    self.cl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)

    # evaluation only: no training or validation loop is run
    self.trainer.test(model=model, datamodule=self.cl_dataset)
The main method to run the continual learning main evaluation.