r"""
The submodule in `pipelines` for continual unlearning full evaluation.
"""

__all__ = ["CULFullEvaluation"]

import logging

import hydra
import lightning as L
import torch
from lightning import Callback, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, ListConfig

from clarena.cl_datasets import CLDataset
from clarena.utils.eval import CULEvaluation

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CULFullEvaluation:
    r"""The base class for continual unlearning full evaluation.

    Loads the main model, the reference retrain model and (optionally) the
    reference original model from checkpoint files, then evaluates them on
    the CL dataset through a `CULEvaluation` module run by a Lightning
    `Trainer`.
    """

    def __init__(self, cfg: DictConfig) -> None:
        r"""Initialise the full evaluation pipeline from its config.

        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the continual unlearning full evaluation.
        """

        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        # idiom fix: call through the instance rather than
        # `CULFullEvaluation.sanity_check(self)`
        self.sanity_check()

        # required config fields
        self.main_model_path: str = cfg.main_model_path
        r"""The path to the model file to load the main model from."""
        self.refretrain_model_path: str = cfg.refretrain_model_path
        r"""The path to the model file to load the reference retrain model from."""
        # a missing `if_run_reforiginal` counts as enabled; only an explicit
        # `false` disables the reference original evaluation
        self.if_run_reforiginal: bool = cfg.get("if_run_reforiginal") is not False
        r"""Whether reference original evaluation is enabled."""
        self.reforiginal_model_path: str | None = (
            cfg.reforiginal_model_path
            if self.if_run_reforiginal and cfg.get("reforiginal_model_path")
            else None
        )
        r"""The path to the model file to load the reference original model from."""
        self.cl_paradigm: str = cfg.cl_paradigm
        r"""The continual learning paradigm."""

        # `dd_eval_tasks` is either an explicit list of task IDs or an int N
        # meaning tasks 1..N
        self.dd_eval_tasks: list[int] = (
            list(cfg.dd_eval_tasks)
            if isinstance(cfg.dd_eval_tasks, ListConfig)
            else list(range(1, cfg.dd_eval_tasks + 1))
        )
        r"""The list of tasks to be evaluated for DD."""
        # AG evaluation only makes sense when a reference original model is
        # available; otherwise fall back to an empty task list
        ag_eval_tasks_cfg = cfg.get("ag_eval_tasks")
        self.ag_eval_tasks: list[int] = (
            (
                list(ag_eval_tasks_cfg)
                if isinstance(ag_eval_tasks_cfg, ListConfig)
                else list(range(1, ag_eval_tasks_cfg + 1))
            )
            if self.reforiginal_model_path and ag_eval_tasks_cfg is not None
            else []
        )
        r"""The list of tasks to be evaluated for AG."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""

        # components (assigned later by the `instantiate_*` methods)
        self.cl_dataset: CLDataset
        r"""CL dataset object."""
        self.evaluation_module: CULEvaluation
        r"""Evaluation module for continual unlearning full evaluation."""
        self.lightning_loggers: list[Logger]
        r"""Lightning logger objects."""
        self.callbacks: list[Callback]
        r"""Callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        **Raises:**
        - **KeyError**: when a required config field is missing.
        - **ValueError**: when `cl_paradigm` is not one of 'TIL', 'CIL', 'DIL'.
        """

        # check required config fields
        required_config_fields = [
            "pipeline",
            "main_model_path",
            "cl_paradigm",
            "global_seed",
            "cl_dataset",
            "trainer",
            "metrics",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            # fixed: check presence with `is None` instead of truthiness, so a
            # legitimate falsy value (e.g. `global_seed: 0`) is not rejected
            if self.cfg.get(field) is None:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # check cl_paradigm
        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
            raise ValueError(
                f"Field cl_paradigm should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
            )

        # warn if any reference experiment result is not provided
        if not self.cfg.get("refretrain_model_path"):
            pylogger.warning(
                "`refretrain_model_path` not provided. Distribution Distance (DD) cannot be calculated."
            )

        if self.cfg.get("if_run_reforiginal") is False:
            pylogger.info(
                "`if_run_reforiginal` is false. Skip loading the reference original model and AG evaluation."
            )
        elif not self.cfg.get("reforiginal_model_path"):
            pylogger.warning(
                "`reforiginal_model_path` not provided. Accuracy Gain (AG) cannot be calculated."
            )

    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            cl_dataset_cfg.get("_target_"),
        )
        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            cl_dataset_cfg.get("_target_"),
        )

    def instantiate_evaluation_module(self) -> None:
        r"""Instantiate the evaluation module object.

        Loads the main / reference retrain / reference original checkpoints
        from their configured paths and wraps them in a `CULEvaluation`.
        """
        pylogger.debug(
            "Instantiating evaluation module (clarena.utils.eval.CULEvaluation)...",
        )

        pylogger.debug("Loading main model from %s.", self.main_model_path)
        pylogger.debug(
            "Loading reference retrain model from %s.", self.refretrain_model_path
        )
        if self.reforiginal_model_path:
            pylogger.debug(
                "Loading reference original model from %s.",
                self.reforiginal_model_path,
            )

        def _load_checkpoint(path: str):
            # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which blocks
            # loading custom classes. We explicitly set weights_only=False to
            # allow loading full objects from our own checkpoints.
            return torch.load(path, map_location="cpu", weights_only=False)

        main_model = _load_checkpoint(self.main_model_path)
        refretrain_model = _load_checkpoint(self.refretrain_model_path)
        reforiginal_model = (
            _load_checkpoint(self.reforiginal_model_path)
            if self.reforiginal_model_path
            else None
        )

        pylogger.debug("Loaded main model type: %s", type(main_model))
        pylogger.debug(
            "Loaded reference retrain model type: %s", type(refretrain_model)
        )
        if reforiginal_model is not None:
            pylogger.debug(
                "Loaded reference original model type: %s", type(reforiginal_model)
            )

        self.evaluation_module = CULEvaluation(
            main_model=main_model,
            refretrain_model=refretrain_model,
            reforiginal_model=reforiginal_model,
            dd_eval_task_ids=self.dd_eval_tasks,
            ag_eval_task_ids=self.ag_eval_tasks,
        )
        pylogger.debug(
            "Evaluation module (clarena.utils.eval.CULEvaluation) instantiated!"
        )

    def instantiate_callbacks(
        self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
    ) -> None:
        r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
        pylogger.debug(
            "Instantiating callbacks (lightning.Callback)...",
        )

        # drop the AG metric when the reference original model (or its task
        # list) is unavailable — it could not be computed anyway
        enabled_metric_cfgs = []
        for callback in metrics_cfg:
            if (
                callback.get("_target_") == "clarena.metrics.CULAccuracyGain"
                and (not self.reforiginal_model_path or not self.ag_eval_tasks)
            ):
                pylogger.info(
                    "Skipping `clarena.metrics.CULAccuracyGain` because reference original evaluation is disabled."
                )
                continue
            enabled_metric_cfgs.append(callback)

        # instantiate metric callbacks
        metric_callbacks = [
            hydra.utils.instantiate(callback) for callback in enabled_metric_cfgs
        ]

        # instantiate other callbacks
        other_callbacks = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # add metric callbacks to the list of callbacks
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""

        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(
            trainer_cfg,
            callbacks=callbacks,
        )
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire evaluation.

        **Args:**
        - **global_seed** (`int`): the seed passed to `lightning.seed_everything`.
        """
        # fixed: seed with the argument instead of ignoring it in favor of
        # `self.global_seed` — previously the log reported the argument while
        # seeding used the attribute, which diverge if a caller passes a
        # different value.
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the continual unlearning full evaluation."""

        self.set_global_seed(self.global_seed)

        # components
        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
        self.instantiate_evaluation_module()
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics,
            callbacks_cfg=self.cfg.callbacks,
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after callbacks

        # setup tasks for dataset and evaluation module
        self.cl_dataset.setup_tasks_eval(
            eval_tasks=sorted(set(self.dd_eval_tasks + self.ag_eval_tasks))
        )

        # evaluation
        self.trainer.test(
            model=self.evaluation_module,
            datamodule=self.cl_dataset,
        )
        # please note this will set up last task dataset twice, which is fine
class CULFullEvaluation:
    r"""The base class for continual unlearning full evaluation."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""**Args:**
        - **cfg** (`DictConfig`): the complete config dict for the continual unlearning main evaluation.
        """

        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        CULFullEvaluation.sanity_check(self)

        # required config fields
        self.main_model_path: str = cfg.main_model_path
        r"""The path to the model file to load the main model from."""
        self.refretrain_model_path: str = cfg.refretrain_model_path
        r"""The path to the model file to load the reference retrain model from."""
        self.if_run_reforiginal: bool = cfg.get("if_run_reforiginal") is not False
        r"""Whether reference original evaluation is enabled."""
        if self.if_run_reforiginal and cfg.get("reforiginal_model_path"):
            reforiginal_path = cfg.reforiginal_model_path
        else:
            reforiginal_path = None
        self.reforiginal_model_path: str | None = reforiginal_path
        r"""The path to the model file to load the reference original model from."""
        self.cl_paradigm: str = cfg.cl_paradigm
        r"""The continual learning paradigm."""

        # an int N in the config is shorthand for task IDs 1..N
        dd_spec = cfg.dd_eval_tasks
        if isinstance(dd_spec, ListConfig):
            dd_task_ids = list(dd_spec)
        else:
            dd_task_ids = list(range(1, dd_spec + 1))
        self.dd_eval_tasks: list[int] = dd_task_ids
        r"""The list of tasks to be evaluated for DD."""
        ag_spec = cfg.get("ag_eval_tasks")
        if self.reforiginal_model_path and ag_spec is not None:
            if isinstance(ag_spec, ListConfig):
                ag_task_ids = list(ag_spec)
            else:
                ag_task_ids = list(range(1, ag_spec + 1))
        else:
            # no reference original model means no AG evaluation at all
            ag_task_ids = []
        self.ag_eval_tasks: list[int] = ag_task_ids
        r"""The list of tasks to be evaluated for AG."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""

        # components
        self.cl_dataset: CLDataset
        r"""CL dataset object."""
        self.evaluation_module: CULEvaluation
        r"""Evaluation module for continual unlearning full evaluation."""
        self.lightning_loggers: list[Logger]
        r"""Lightning logger objects."""
        self.callbacks: list[Callback]
        r"""Callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config."""

        # every field below must be present in the config
        for required_field in (
            "pipeline",
            "main_model_path",
            "cl_paradigm",
            "global_seed",
            "cl_dataset",
            "trainer",
            "metrics",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ):
            if not self.cfg.get(required_field):
                raise KeyError(
                    f"Field `{required_field}` is required in the experiment index config."
                )

        # only three CL paradigms are supported
        if self.cfg.cl_paradigm not in ("TIL", "CIL", "DIL"):
            raise ValueError(
                f"Field cl_paradigm should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
            )

        # warn when reference experiment results are missing
        if not self.cfg.get("refretrain_model_path"):
            pylogger.warning(
                "`refretrain_model_path` not provided. Distribution Distance (DD) cannot be calculated."
            )

        if self.cfg.get("if_run_reforiginal") is False:
            pylogger.info(
                "`if_run_reforiginal` is false. Skip loading the reference original model and AG evaluation."
            )
        elif not self.cfg.get("reforiginal_model_path"):
            pylogger.warning(
                "`reforiginal_model_path` not provided. Accuracy Gain (AG) cannot be calculated."
            )

    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
        target = cl_dataset_cfg.get("_target_")
        pylogger.debug(
            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
            target,
        )
        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
        pylogger.debug(
            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
            target,
        )

    def instantiate_evaluation_module(self) -> None:
        r"""Instantiate the evaluation module object."""
        pylogger.debug(
            "Instantiating evaluation module (clarena.utils.eval.CULEvaluation)...",
        )

        pylogger.debug("Loading main model from %s.", self.main_model_path)
        pylogger.debug(
            "Loading reference retrain model from %s.", self.refretrain_model_path
        )
        if self.reforiginal_model_path:
            pylogger.debug(
                "Loading reference original model from %s.",
                self.reforiginal_model_path,
            )

        # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which blocks
        # loading custom classes. We explicitly set weights_only=False to
        # allow loading full objects from our own checkpoints.
        main_model = torch.load(
            self.main_model_path, map_location="cpu", weights_only=False
        )
        refretrain_model = torch.load(
            self.refretrain_model_path, map_location="cpu", weights_only=False
        )
        reforiginal_model = None
        if self.reforiginal_model_path:
            reforiginal_model = torch.load(
                self.reforiginal_model_path, map_location="cpu", weights_only=False
            )

        pylogger.debug("Loaded main model type: %s", type(main_model))
        pylogger.debug(
            "Loaded reference retrain model type: %s", type(refretrain_model)
        )
        if reforiginal_model is not None:
            pylogger.debug(
                "Loaded reference original model type: %s", type(reforiginal_model)
            )

        self.evaluation_module = CULEvaluation(
            main_model=main_model,
            refretrain_model=refretrain_model,
            reforiginal_model=reforiginal_model,
            dd_eval_task_ids=self.dd_eval_tasks,
            ag_eval_task_ids=self.ag_eval_tasks,
        )
        pylogger.debug(
            "Evaluation module (clarena.utils.eval.CULEvaluation) instantiated!"
        )

    def instantiate_callbacks(
        self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
    ) -> None:
        r"""Instantiate the list of callback objects from `metrics_cfg` and
        `callbacks_cfg`. `metrics_cfg` lists the metric callbacks and
        `callbacks_cfg` the non-metric ones; the result contains both."""
        pylogger.debug(
            "Instantiating callbacks (lightning.Callback)...",
        )

        ag_disabled = not self.reforiginal_model_path or not self.ag_eval_tasks

        kept_metric_cfgs = []
        for metric_cfg in metrics_cfg:
            is_ag_metric = (
                metric_cfg.get("_target_") == "clarena.metrics.CULAccuracyGain"
            )
            if is_ag_metric and ag_disabled:
                pylogger.info(
                    "Skipping `clarena.metrics.CULAccuracyGain` because reference original evaluation is disabled."
                )
                continue
            kept_metric_cfgs.append(metric_cfg)

        metric_callbacks = [
            hydra.utils.instantiate(metric_cfg) for metric_cfg in kept_metric_cfgs
        ]
        non_metric_callbacks = [
            hydra.utils.instantiate(cb_cfg) for cb_cfg in callbacks_cfg
        ]

        # metric callbacks first, then the rest
        self.callbacks = metric_callbacks + non_metric_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""

        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire evaluation."""
        L.seed_everything(self.global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the continual unlearning full evaluation."""

        self.set_global_seed(self.global_seed)

        # build the components; the trainer goes last since it consumes the
        # callbacks
        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
        self.instantiate_evaluation_module()
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics,
            callbacks_cfg=self.cfg.callbacks,
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            callbacks=self.callbacks,
        )

        # deduplicated, ordered union of DD and AG task IDs
        eval_task_ids = sorted({*self.dd_eval_tasks, *self.ag_eval_tasks})
        self.cl_dataset.setup_tasks_eval(eval_tasks=eval_task_ids)

        # evaluation (the last task dataset gets set up twice, which is fine)
        self.trainer.test(
            model=self.evaluation_module,
            datamodule=self.cl_dataset,
        )
The base class for continual unlearning full evaluation.
27 def __init__(self, cfg: DictConfig) -> None: 28 r"""**Args:** 29 - **cfg** (`DictConfig`): the complete config dict for the continual unlearning main evaluation. 30 """ 31 32 self.cfg: DictConfig = cfg 33 r"""The complete config dict.""" 34 35 CULFullEvaluation.sanity_check(self) 36 37 # required config fields 38 self.main_model_path: str = cfg.main_model_path 39 r"""The path to the model file to load the main model from.""" 40 self.refretrain_model_path: str = cfg.refretrain_model_path 41 r"""The path to the model file to load the reference retrain model from.""" 42 self.if_run_reforiginal: bool = cfg.get("if_run_reforiginal") is not False 43 r"""Whether reference original evaluation is enabled.""" 44 self.reforiginal_model_path: str | None = ( 45 cfg.reforiginal_model_path 46 if self.if_run_reforiginal and cfg.get("reforiginal_model_path") 47 else None 48 ) 49 r"""The path to the model file to load the reference original model from.""" 50 self.cl_paradigm: str = cfg.cl_paradigm 51 r"""The continual learning paradigm.""" 52 53 self.dd_eval_tasks: list[int] = ( 54 list(cfg.dd_eval_tasks) 55 if isinstance(cfg.dd_eval_tasks, ListConfig) 56 else list(range(1, cfg.dd_eval_tasks + 1)) 57 ) 58 r"""The list of tasks to be evaluated for DD.""" 59 ag_eval_tasks_cfg = cfg.get("ag_eval_tasks") 60 self.ag_eval_tasks: list[int] = ( 61 ( 62 list(ag_eval_tasks_cfg) 63 if isinstance(ag_eval_tasks_cfg, ListConfig) 64 else list(range(1, ag_eval_tasks_cfg + 1)) 65 ) 66 if self.reforiginal_model_path and ag_eval_tasks_cfg is not None 67 else [] 68 ) 69 r"""The list of tasks to be evaluated for AG.""" 70 self.global_seed: int = cfg.global_seed 71 r"""The global seed for the entire experiment.""" 72 self.output_dir: str = cfg.output_dir 73 r"""The folder for storing the experiment results.""" 74 75 # components 76 self.cl_dataset: CLDataset 77 r"""CL dataset object.""" 78 self.evaluation_module: CULEvaluation 79 r"""Evaluation module for continual unlearning full evaluation.""" 80 
self.lightning_loggers: list[Logger] 81 r"""Lightning logger objects.""" 82 self.callbacks: list[Callback] 83 r"""Callback objects.""" 84 self.trainer: Trainer 85 r"""Trainer object."""
Args:
- cfg (
DictConfig): the complete config dict for the continual unlearning main evaluation.
The path to the model file to load the reference original model from.
Evaluation module for continual unlearning full evaluation.
87 def sanity_check(self) -> None: 88 r"""Sanity check for config.""" 89 90 # check required config fields 91 required_config_fields = [ 92 "pipeline", 93 "main_model_path", 94 "cl_paradigm", 95 "global_seed", 96 "cl_dataset", 97 "trainer", 98 "metrics", 99 "callbacks", 100 "output_dir", 101 # "hydra" is excluded as it doesn't appear 102 "misc", 103 ] 104 for field in required_config_fields: 105 if not self.cfg.get(field): 106 raise KeyError( 107 f"Field `{field}` is required in the experiment index config." 108 ) 109 110 # check cl_paradigm 111 if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]: 112 raise ValueError( 113 f"Field cl_paradigm should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!" 114 ) 115 116 # warn if any reference experiment result is not provided 117 if not self.cfg.get("refretrain_model_path"): 118 pylogger.warning( 119 "`refretrain_model_path` not provided. Distribution Distance (DD) cannot be calculated." 120 ) 121 122 if self.cfg.get("if_run_reforiginal") is False: 123 pylogger.info( 124 "`if_run_reforiginal` is false. Skip loading the reference original model and AG evaluation." 125 ) 126 elif not self.cfg.get("reforiginal_model_path"): 127 pylogger.warning( 128 "`reforiginal_model_path` not provided. Accuracy Gain (AG) cannot be calculated." 129 )
Sanity check for config.
131 def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None: 132 r"""Instantiate the CL dataset object from `cl_dataset_cfg`.""" 133 pylogger.debug( 134 "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...", 135 cl_dataset_cfg.get("_target_"), 136 ) 137 self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg) 138 pylogger.debug( 139 "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!", 140 cl_dataset_cfg.get("_target_"), 141 )
Instantiate the CL dataset object from cl_dataset_cfg.
143 def instantiate_evaluation_module(self) -> None: 144 r"""Instantiate the evaluation module object.""" 145 pylogger.debug( 146 "Instantiating evaluation module (clarena.utils.eval.CULEvaluation)...", 147 ) 148 149 pylogger.debug("Loading main model from %s.", self.main_model_path) 150 pylogger.debug( 151 "Loading reference retrain model from %s.", self.refretrain_model_path 152 ) 153 if self.reforiginal_model_path: 154 pylogger.debug( 155 "Loading reference original model from %s.", 156 self.reforiginal_model_path, 157 ) 158 159 # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which blocks loading custom classes. 160 # We explicitly set weights_only=False to allow loading full objects from our own checkpoints. 161 main_model = torch.load( 162 self.main_model_path, map_location="cpu", weights_only=False 163 ) 164 refretrain_model = torch.load( 165 self.refretrain_model_path, map_location="cpu", weights_only=False 166 ) 167 reforiginal_model = ( 168 torch.load( 169 self.reforiginal_model_path, map_location="cpu", weights_only=False 170 ) 171 if self.reforiginal_model_path 172 else None 173 ) 174 175 pylogger.debug("Loaded main model type: %s", type(main_model)) 176 pylogger.debug("Loaded reference retrain model type: %s", type(refretrain_model)) 177 if reforiginal_model is not None: 178 pylogger.debug( 179 "Loaded reference original model type: %s", type(reforiginal_model) 180 ) 181 182 self.evaluation_module = CULEvaluation( 183 main_model=main_model, 184 refretrain_model=refretrain_model, 185 reforiginal_model=reforiginal_model, 186 dd_eval_task_ids=self.dd_eval_tasks, 187 ag_eval_task_ids=self.ag_eval_tasks, 188 ) 189 pylogger.debug( 190 "Evaluation module (clarena.utils.eval.CULEvaluation) instantiated!" 191 )
Instantiate the evaluation module object.
193 def instantiate_callbacks( 194 self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig 195 ) -> None: 196 r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.""" 197 pylogger.debug( 198 "Instantiating callbacks (lightning.Callback)...", 199 ) 200 201 enabled_metric_cfgs = [] 202 for callback in metrics_cfg: 203 if ( 204 callback.get("_target_") == "clarena.metrics.CULAccuracyGain" 205 and (not self.reforiginal_model_path or not self.ag_eval_tasks) 206 ): 207 pylogger.info( 208 "Skipping `clarena.metrics.CULAccuracyGain` because reference original evaluation is disabled." 209 ) 210 continue 211 enabled_metric_cfgs.append(callback) 212 213 # instantiate metric callbacks 214 metric_callbacks = [ 215 hydra.utils.instantiate(callback) for callback in enabled_metric_cfgs 216 ] 217 218 # instantiate other callbacks 219 other_callbacks = [ 220 hydra.utils.instantiate(callback) for callback in callbacks_cfg 221 ] 222 223 # add metric callbacks to the list of callbacks 224 self.callbacks = metric_callbacks + other_callbacks 225 pylogger.debug("Callbacks (lightning.Callback) instantiated!")
Instantiate the list of callback objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.
227 def instantiate_trainer( 228 self, 229 trainer_cfg: DictConfig, 230 callbacks: list[Callback], 231 ) -> None: 232 r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`.""" 233 234 pylogger.debug("Instantiating trainer (lightning.Trainer)...") 235 self.trainer = hydra.utils.instantiate( 236 trainer_cfg, 237 callbacks=callbacks, 238 ) 239 pylogger.debug("Trainer (lightning.Trainer) instantiated!")
Instantiate the trainer object from trainer_cfg and callbacks.
241 def set_global_seed(self, global_seed: int) -> None: 242 r"""Set the `global_seed` for the entire evaluation.""" 243 L.seed_everything(self.global_seed, workers=True) 244 pylogger.debug("Global seed is set as %d.", global_seed)
Set the global_seed for the entire evaluation.
246 def run(self) -> None: 247 r"""The main method to run the continual unlearning full evaluation.""" 248 249 self.set_global_seed(self.global_seed) 250 251 # components 252 self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset) 253 self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm) 254 self.instantiate_evaluation_module() 255 self.instantiate_callbacks( 256 metrics_cfg=self.cfg.metrics, 257 callbacks_cfg=self.cfg.callbacks, 258 ) 259 self.instantiate_trainer( 260 trainer_cfg=self.cfg.trainer, 261 callbacks=self.callbacks, 262 ) # trainer should be instantiated after callbacks 263 264 # setup tasks for dataset and evaluation module 265 self.cl_dataset.setup_tasks_eval( 266 eval_tasks=sorted(set(self.dd_eval_tasks + self.ag_eval_tasks)) 267 ) 268 269 # evaluation 270 self.trainer.test( 271 model=self.evaluation_module, 272 datamodule=self.cl_dataset, 273 ) 274 # please note this will set up last task dataset twice, which is fine
The main method to run the continual unlearning full evaluation.