clarena.pipelines.cul_full_eval

The submodule in pipelines for continual unlearning full evaluation.

  1r"""
  2The submodule in `pipelines` for continual unlearning full evaluation.
  3"""
  4
  5__all__ = ["CULFullEvaluation"]
  6
  7import logging
  8
  9import hydra
 10import lightning as L
 11import torch
 12from lightning import Callback, Trainer
 13from lightning.pytorch.loggers import Logger
 14from omegaconf import DictConfig, ListConfig
 15
 16from clarena.cl_datasets import CLDataset
 17from clarena.utils.eval import CULEvaluation
 18
 19# always get logger for built-in logging in each module
 20pylogger = logging.getLogger(__name__)
 21
 22
 23class CULFullEvaluation:
 24    r"""The base class for continual unlearning full evaluation."""
 25
 26    def __init__(self, cfg: DictConfig) -> None:
 27        r"""**Args:**
 28        - **cfg** (`DictConfig`): the complete config dict for the continual unlearning main evaluation.
 29        """
 30
 31        self.cfg: DictConfig = cfg
 32        r"""The complete config dict."""
 33
 34        CULFullEvaluation.sanity_check(self)
 35
 36        # required config fields
 37        self.main_model_path: str = cfg.main_model_path
 38        r"""The path to the model file to load the main model from."""
 39        self.refretrain_model_path: str = cfg.refretrain_model_path
 40        r"""The path to the model file to load the reference retrain model from."""
 41        self.if_run_reforiginal: bool = cfg.get("if_run_reforiginal") is not False
 42        r"""Whether reference original evaluation is enabled."""
 43        self.reforiginal_model_path: str | None = (
 44            cfg.reforiginal_model_path
 45            if self.if_run_reforiginal and cfg.get("reforiginal_model_path")
 46            else None
 47        )
 48        r"""The path to the model file to load the reference original model from."""
 49        self.cl_paradigm: str = cfg.cl_paradigm
 50        r"""The continual learning paradigm."""
 51
 52        self.dd_eval_tasks: list[int] = (
 53            list(cfg.dd_eval_tasks)
 54            if isinstance(cfg.dd_eval_tasks, ListConfig)
 55            else list(range(1, cfg.dd_eval_tasks + 1))
 56        )
 57        r"""The list of tasks to be evaluated for DD."""
 58        ag_eval_tasks_cfg = cfg.get("ag_eval_tasks")
 59        self.ag_eval_tasks: list[int] = (
 60            (
 61                list(ag_eval_tasks_cfg)
 62                if isinstance(ag_eval_tasks_cfg, ListConfig)
 63                else list(range(1, ag_eval_tasks_cfg + 1))
 64            )
 65            if self.reforiginal_model_path and ag_eval_tasks_cfg is not None
 66            else []
 67        )
 68        r"""The list of tasks to be evaluated for AG."""
 69        self.global_seed: int = cfg.global_seed
 70        r"""The global seed for the entire experiment."""
 71        self.output_dir: str = cfg.output_dir
 72        r"""The folder for storing the experiment results."""
 73
 74        # components
 75        self.cl_dataset: CLDataset
 76        r"""CL dataset object."""
 77        self.evaluation_module: CULEvaluation
 78        r"""Evaluation module for continual unlearning full evaluation."""
 79        self.lightning_loggers: list[Logger]
 80        r"""Lightning logger objects."""
 81        self.callbacks: list[Callback]
 82        r"""Callback objects."""
 83        self.trainer: Trainer
 84        r"""Trainer object."""
 85
 86    def sanity_check(self) -> None:
 87        r"""Sanity check for config."""
 88
 89        # check required config fields
 90        required_config_fields = [
 91            "pipeline",
 92            "main_model_path",
 93            "cl_paradigm",
 94            "global_seed",
 95            "cl_dataset",
 96            "trainer",
 97            "metrics",
 98            "callbacks",
 99            "output_dir",
100            # "hydra" is excluded as it doesn't appear
101            "misc",
102        ]
103        for field in required_config_fields:
104            if not self.cfg.get(field):
105                raise KeyError(
106                    f"Field `{field}` is required in the experiment index config."
107                )
108
109        # check cl_paradigm
110        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
111            raise ValueError(
112                f"Field cl_paradigm should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
113            )
114
115        # warn if any reference experiment result is not provided
116        if not self.cfg.get("refretrain_model_path"):
117            pylogger.warning(
118                "`refretrain_model_path` not provided. Distribution Distance (DD) cannot be calculated."
119            )
120
121        if self.cfg.get("if_run_reforiginal") is False:
122            pylogger.info(
123                "`if_run_reforiginal` is false. Skip loading the reference original model and AG evaluation."
124            )
125        elif not self.cfg.get("reforiginal_model_path"):
126            pylogger.warning(
127                "`reforiginal_model_path` not provided. Accuracy Gain (AG) cannot be calculated."
128            )
129
130    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
131        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
132        pylogger.debug(
133            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
134            cl_dataset_cfg.get("_target_"),
135        )
136        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
137        pylogger.debug(
138            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
139            cl_dataset_cfg.get("_target_"),
140        )
141
142    def instantiate_evaluation_module(self) -> None:
143        r"""Instantiate the evaluation module object."""
144        pylogger.debug(
145            "Instantiating evaluation module (clarena.utils.eval.CULEvaluation)...",
146        )
147
148        pylogger.debug("Loading main model from %s.", self.main_model_path)
149        pylogger.debug(
150            "Loading reference retrain model from %s.", self.refretrain_model_path
151        )
152        if self.reforiginal_model_path:
153            pylogger.debug(
154                "Loading reference original model from %s.",
155                self.reforiginal_model_path,
156            )
157
158        # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which blocks loading custom classes.
159        # We explicitly set weights_only=False to allow loading full objects from our own checkpoints.
160        main_model = torch.load(
161            self.main_model_path, map_location="cpu", weights_only=False
162        )
163        refretrain_model = torch.load(
164            self.refretrain_model_path, map_location="cpu", weights_only=False
165        )
166        reforiginal_model = (
167            torch.load(
168                self.reforiginal_model_path, map_location="cpu", weights_only=False
169            )
170            if self.reforiginal_model_path
171            else None
172        )
173
174        pylogger.debug("Loaded main model type: %s", type(main_model))
175        pylogger.debug("Loaded reference retrain model type: %s", type(refretrain_model))
176        if reforiginal_model is not None:
177            pylogger.debug(
178                "Loaded reference original model type: %s", type(reforiginal_model)
179            )
180
181        self.evaluation_module = CULEvaluation(
182            main_model=main_model,
183            refretrain_model=refretrain_model,
184            reforiginal_model=reforiginal_model,
185            dd_eval_task_ids=self.dd_eval_tasks,
186            ag_eval_task_ids=self.ag_eval_tasks,
187        )
188        pylogger.debug(
189            "Evaluation module (clarena.utils.eval.CULEvaluation) instantiated!"
190        )
191
192    def instantiate_callbacks(
193        self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
194    ) -> None:
195        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
196        pylogger.debug(
197            "Instantiating callbacks (lightning.Callback)...",
198        )
199
200        enabled_metric_cfgs = []
201        for callback in metrics_cfg:
202            if (
203                callback.get("_target_") == "clarena.metrics.CULAccuracyGain"
204                and (not self.reforiginal_model_path or not self.ag_eval_tasks)
205            ):
206                pylogger.info(
207                    "Skipping `clarena.metrics.CULAccuracyGain` because reference original evaluation is disabled."
208                )
209                continue
210            enabled_metric_cfgs.append(callback)
211
212        # instantiate metric callbacks
213        metric_callbacks = [
214            hydra.utils.instantiate(callback) for callback in enabled_metric_cfgs
215        ]
216
217        # instantiate other callbacks
218        other_callbacks = [
219            hydra.utils.instantiate(callback) for callback in callbacks_cfg
220        ]
221
222        # add metric callbacks to the list of callbacks
223        self.callbacks = metric_callbacks + other_callbacks
224        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
225
226    def instantiate_trainer(
227        self,
228        trainer_cfg: DictConfig,
229        callbacks: list[Callback],
230    ) -> None:
231        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""
232
233        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
234        self.trainer = hydra.utils.instantiate(
235            trainer_cfg,
236            callbacks=callbacks,
237        )
238        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
239
240    def set_global_seed(self, global_seed: int) -> None:
241        r"""Set the `global_seed` for the entire evaluation."""
242        L.seed_everything(self.global_seed, workers=True)
243        pylogger.debug("Global seed is set as %d.", global_seed)
244
245    def run(self) -> None:
246        r"""The main method to run the continual unlearning full evaluation."""
247
248        self.set_global_seed(self.global_seed)
249
250        # components
251        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
252        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
253        self.instantiate_evaluation_module()
254        self.instantiate_callbacks(
255            metrics_cfg=self.cfg.metrics,
256            callbacks_cfg=self.cfg.callbacks,
257        )
258        self.instantiate_trainer(
259            trainer_cfg=self.cfg.trainer,
260            callbacks=self.callbacks,
261        )  # trainer should be instantiated after callbacks
262
263        # setup tasks for dataset and evaluation module
264        self.cl_dataset.setup_tasks_eval(
265            eval_tasks=sorted(set(self.dd_eval_tasks + self.ag_eval_tasks))
266        )
267
268        # evaluation
269        self.trainer.test(
270            model=self.evaluation_module,
271            datamodule=self.cl_dataset,
272        )
273        # please note this will set up last task dataset twice, which is fine
class CULFullEvaluation:
 24class CULFullEvaluation:
 25    r"""The base class for continual unlearning full evaluation."""
 26
 27    def __init__(self, cfg: DictConfig) -> None:
 28        r"""**Args:**
 29        - **cfg** (`DictConfig`): the complete config dict for the continual unlearning main evaluation.
 30        """
 31
 32        self.cfg: DictConfig = cfg
 33        r"""The complete config dict."""
 34
 35        CULFullEvaluation.sanity_check(self)
 36
 37        # required config fields
 38        self.main_model_path: str = cfg.main_model_path
 39        r"""The path to the model file to load the main model from."""
 40        self.refretrain_model_path: str = cfg.refretrain_model_path
 41        r"""The path to the model file to load the reference retrain model from."""
 42        self.if_run_reforiginal: bool = cfg.get("if_run_reforiginal") is not False
 43        r"""Whether reference original evaluation is enabled."""
 44        self.reforiginal_model_path: str | None = (
 45            cfg.reforiginal_model_path
 46            if self.if_run_reforiginal and cfg.get("reforiginal_model_path")
 47            else None
 48        )
 49        r"""The path to the model file to load the reference original model from."""
 50        self.cl_paradigm: str = cfg.cl_paradigm
 51        r"""The continual learning paradigm."""
 52
 53        self.dd_eval_tasks: list[int] = (
 54            list(cfg.dd_eval_tasks)
 55            if isinstance(cfg.dd_eval_tasks, ListConfig)
 56            else list(range(1, cfg.dd_eval_tasks + 1))
 57        )
 58        r"""The list of tasks to be evaluated for DD."""
 59        ag_eval_tasks_cfg = cfg.get("ag_eval_tasks")
 60        self.ag_eval_tasks: list[int] = (
 61            (
 62                list(ag_eval_tasks_cfg)
 63                if isinstance(ag_eval_tasks_cfg, ListConfig)
 64                else list(range(1, ag_eval_tasks_cfg + 1))
 65            )
 66            if self.reforiginal_model_path and ag_eval_tasks_cfg is not None
 67            else []
 68        )
 69        r"""The list of tasks to be evaluated for AG."""
 70        self.global_seed: int = cfg.global_seed
 71        r"""The global seed for the entire experiment."""
 72        self.output_dir: str = cfg.output_dir
 73        r"""The folder for storing the experiment results."""
 74
 75        # components
 76        self.cl_dataset: CLDataset
 77        r"""CL dataset object."""
 78        self.evaluation_module: CULEvaluation
 79        r"""Evaluation module for continual unlearning full evaluation."""
 80        self.lightning_loggers: list[Logger]
 81        r"""Lightning logger objects."""
 82        self.callbacks: list[Callback]
 83        r"""Callback objects."""
 84        self.trainer: Trainer
 85        r"""Trainer object."""
 86
 87    def sanity_check(self) -> None:
 88        r"""Sanity check for config."""
 89
 90        # check required config fields
 91        required_config_fields = [
 92            "pipeline",
 93            "main_model_path",
 94            "cl_paradigm",
 95            "global_seed",
 96            "cl_dataset",
 97            "trainer",
 98            "metrics",
 99            "callbacks",
100            "output_dir",
101            # "hydra" is excluded as it doesn't appear
102            "misc",
103        ]
104        for field in required_config_fields:
105            if not self.cfg.get(field):
106                raise KeyError(
107                    f"Field `{field}` is required in the experiment index config."
108                )
109
110        # check cl_paradigm
111        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
112            raise ValueError(
113                f"Field cl_paradigm should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
114            )
115
116        # warn if any reference experiment result is not provided
117        if not self.cfg.get("refretrain_model_path"):
118            pylogger.warning(
119                "`refretrain_model_path` not provided. Distribution Distance (DD) cannot be calculated."
120            )
121
122        if self.cfg.get("if_run_reforiginal") is False:
123            pylogger.info(
124                "`if_run_reforiginal` is false. Skip loading the reference original model and AG evaluation."
125            )
126        elif not self.cfg.get("reforiginal_model_path"):
127            pylogger.warning(
128                "`reforiginal_model_path` not provided. Accuracy Gain (AG) cannot be calculated."
129            )
130
131    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
132        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
133        pylogger.debug(
134            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
135            cl_dataset_cfg.get("_target_"),
136        )
137        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
138        pylogger.debug(
139            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
140            cl_dataset_cfg.get("_target_"),
141        )
142
143    def instantiate_evaluation_module(self) -> None:
144        r"""Instantiate the evaluation module object."""
145        pylogger.debug(
146            "Instantiating evaluation module (clarena.utils.eval.CULEvaluation)...",
147        )
148
149        pylogger.debug("Loading main model from %s.", self.main_model_path)
150        pylogger.debug(
151            "Loading reference retrain model from %s.", self.refretrain_model_path
152        )
153        if self.reforiginal_model_path:
154            pylogger.debug(
155                "Loading reference original model from %s.",
156                self.reforiginal_model_path,
157            )
158
159        # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which blocks loading custom classes.
160        # We explicitly set weights_only=False to allow loading full objects from our own checkpoints.
161        main_model = torch.load(
162            self.main_model_path, map_location="cpu", weights_only=False
163        )
164        refretrain_model = torch.load(
165            self.refretrain_model_path, map_location="cpu", weights_only=False
166        )
167        reforiginal_model = (
168            torch.load(
169                self.reforiginal_model_path, map_location="cpu", weights_only=False
170            )
171            if self.reforiginal_model_path
172            else None
173        )
174
175        pylogger.debug("Loaded main model type: %s", type(main_model))
176        pylogger.debug("Loaded reference retrain model type: %s", type(refretrain_model))
177        if reforiginal_model is not None:
178            pylogger.debug(
179                "Loaded reference original model type: %s", type(reforiginal_model)
180            )
181
182        self.evaluation_module = CULEvaluation(
183            main_model=main_model,
184            refretrain_model=refretrain_model,
185            reforiginal_model=reforiginal_model,
186            dd_eval_task_ids=self.dd_eval_tasks,
187            ag_eval_task_ids=self.ag_eval_tasks,
188        )
189        pylogger.debug(
190            "Evaluation module (clarena.utils.eval.CULEvaluation) instantiated!"
191        )
192
193    def instantiate_callbacks(
194        self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
195    ) -> None:
196        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
197        pylogger.debug(
198            "Instantiating callbacks (lightning.Callback)...",
199        )
200
201        enabled_metric_cfgs = []
202        for callback in metrics_cfg:
203            if (
204                callback.get("_target_") == "clarena.metrics.CULAccuracyGain"
205                and (not self.reforiginal_model_path or not self.ag_eval_tasks)
206            ):
207                pylogger.info(
208                    "Skipping `clarena.metrics.CULAccuracyGain` because reference original evaluation is disabled."
209                )
210                continue
211            enabled_metric_cfgs.append(callback)
212
213        # instantiate metric callbacks
214        metric_callbacks = [
215            hydra.utils.instantiate(callback) for callback in enabled_metric_cfgs
216        ]
217
218        # instantiate other callbacks
219        other_callbacks = [
220            hydra.utils.instantiate(callback) for callback in callbacks_cfg
221        ]
222
223        # add metric callbacks to the list of callbacks
224        self.callbacks = metric_callbacks + other_callbacks
225        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
226
227    def instantiate_trainer(
228        self,
229        trainer_cfg: DictConfig,
230        callbacks: list[Callback],
231    ) -> None:
232        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""
233
234        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
235        self.trainer = hydra.utils.instantiate(
236            trainer_cfg,
237            callbacks=callbacks,
238        )
239        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
240
241    def set_global_seed(self, global_seed: int) -> None:
242        r"""Set the `global_seed` for the entire evaluation."""
243        L.seed_everything(self.global_seed, workers=True)
244        pylogger.debug("Global seed is set as %d.", global_seed)
245
246    def run(self) -> None:
247        r"""The main method to run the continual unlearning full evaluation."""
248
249        self.set_global_seed(self.global_seed)
250
251        # components
252        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
253        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
254        self.instantiate_evaluation_module()
255        self.instantiate_callbacks(
256            metrics_cfg=self.cfg.metrics,
257            callbacks_cfg=self.cfg.callbacks,
258        )
259        self.instantiate_trainer(
260            trainer_cfg=self.cfg.trainer,
261            callbacks=self.callbacks,
262        )  # trainer should be instantiated after callbacks
263
264        # setup tasks for dataset and evaluation module
265        self.cl_dataset.setup_tasks_eval(
266            eval_tasks=sorted(set(self.dd_eval_tasks + self.ag_eval_tasks))
267        )
268
269        # evaluation
270        self.trainer.test(
271            model=self.evaluation_module,
272            datamodule=self.cl_dataset,
273        )
274        # please note this will set up last task dataset twice, which is fine

The base class for continual unlearning full evaluation.

CULFullEvaluation(cfg: omegaconf.dictconfig.DictConfig)
27    def __init__(self, cfg: DictConfig) -> None:
28        r"""**Args:**
29        - **cfg** (`DictConfig`): the complete config dict for the continual unlearning main evaluation.
30        """
31
32        self.cfg: DictConfig = cfg
33        r"""The complete config dict."""
34
35        CULFullEvaluation.sanity_check(self)
36
37        # required config fields
38        self.main_model_path: str = cfg.main_model_path
39        r"""The path to the model file to load the main model from."""
40        self.refretrain_model_path: str = cfg.refretrain_model_path
41        r"""The path to the model file to load the reference retrain model from."""
42        self.if_run_reforiginal: bool = cfg.get("if_run_reforiginal") is not False
43        r"""Whether reference original evaluation is enabled."""
44        self.reforiginal_model_path: str | None = (
45            cfg.reforiginal_model_path
46            if self.if_run_reforiginal and cfg.get("reforiginal_model_path")
47            else None
48        )
49        r"""The path to the model file to load the reference original model from."""
50        self.cl_paradigm: str = cfg.cl_paradigm
51        r"""The continual learning paradigm."""
52
53        self.dd_eval_tasks: list[int] = (
54            list(cfg.dd_eval_tasks)
55            if isinstance(cfg.dd_eval_tasks, ListConfig)
56            else list(range(1, cfg.dd_eval_tasks + 1))
57        )
58        r"""The list of tasks to be evaluated for DD."""
59        ag_eval_tasks_cfg = cfg.get("ag_eval_tasks")
60        self.ag_eval_tasks: list[int] = (
61            (
62                list(ag_eval_tasks_cfg)
63                if isinstance(ag_eval_tasks_cfg, ListConfig)
64                else list(range(1, ag_eval_tasks_cfg + 1))
65            )
66            if self.reforiginal_model_path and ag_eval_tasks_cfg is not None
67            else []
68        )
69        r"""The list of tasks to be evaluated for AG."""
70        self.global_seed: int = cfg.global_seed
71        r"""The global seed for the entire experiment."""
72        self.output_dir: str = cfg.output_dir
73        r"""The folder for storing the experiment results."""
74
75        # components
76        self.cl_dataset: CLDataset
77        r"""CL dataset object."""
78        self.evaluation_module: CULEvaluation
79        r"""Evaluation module for continual unlearning full evaluation."""
80        self.lightning_loggers: list[Logger]
81        r"""Lightning logger objects."""
82        self.callbacks: list[Callback]
83        r"""Callback objects."""
84        self.trainer: Trainer
85        r"""Trainer object."""

Args:

  • cfg (DictConfig): the complete config dict for the continual unlearning main evaluation.
cfg: omegaconf.dictconfig.DictConfig

The complete config dict.

main_model_path: str

The path to the model file to load the main model from.

refretrain_model_path: str

The path to the model file to load the reference retrain model from.

if_run_reforiginal: bool

Whether reference original evaluation is enabled.

reforiginal_model_path: str | None

The path to the model file to load the reference original model from.

cl_paradigm: str

The continual learning paradigm.

dd_eval_tasks: list[int]

The list of tasks to be evaluated for DD.

ag_eval_tasks: list[int]

The list of tasks to be evaluated for AG.

global_seed: int

The global seed for the entire experiment.

output_dir: str

The folder for storing the experiment results.

CL dataset object.

Evaluation module for continual unlearning full evaluation.

lightning_loggers: list[lightning.pytorch.loggers.logger.Logger]

Lightning logger objects.

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

Callback objects.

trainer: lightning.pytorch.trainer.trainer.Trainer

Trainer object.

def sanity_check(self) -> None:
 87    def sanity_check(self) -> None:
 88        r"""Sanity check for config."""
 89
 90        # check required config fields
 91        required_config_fields = [
 92            "pipeline",
 93            "main_model_path",
 94            "cl_paradigm",
 95            "global_seed",
 96            "cl_dataset",
 97            "trainer",
 98            "metrics",
 99            "callbacks",
100            "output_dir",
101            # "hydra" is excluded as it doesn't appear
102            "misc",
103        ]
104        for field in required_config_fields:
105            if not self.cfg.get(field):
106                raise KeyError(
107                    f"Field `{field}` is required in the experiment index config."
108                )
109
110        # check cl_paradigm
111        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
112            raise ValueError(
113                f"Field cl_paradigm should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
114            )
115
116        # warn if any reference experiment result is not provided
117        if not self.cfg.get("refretrain_model_path"):
118            pylogger.warning(
119                "`refretrain_model_path` not provided. Distribution Distance (DD) cannot be calculated."
120            )
121
122        if self.cfg.get("if_run_reforiginal") is False:
123            pylogger.info(
124                "`if_run_reforiginal` is false. Skip loading the reference original model and AG evaluation."
125            )
126        elif not self.cfg.get("reforiginal_model_path"):
127            pylogger.warning(
128                "`reforiginal_model_path` not provided. Accuracy Gain (AG) cannot be calculated."
129            )

Sanity check for config.

def instantiate_cl_dataset(self, cl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
131    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
132        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
133        pylogger.debug(
134            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
135            cl_dataset_cfg.get("_target_"),
136        )
137        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
138        pylogger.debug(
139            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
140            cl_dataset_cfg.get("_target_"),
141        )

Instantiate the CL dataset object from cl_dataset_cfg.

def instantiate_evaluation_module(self) -> None:
143    def instantiate_evaluation_module(self) -> None:
144        r"""Instantiate the evaluation module object."""
145        pylogger.debug(
146            "Instantiating evaluation module (clarena.utils.eval.CULEvaluation)...",
147        )
148
149        pylogger.debug("Loading main model from %s.", self.main_model_path)
150        pylogger.debug(
151            "Loading reference retrain model from %s.", self.refretrain_model_path
152        )
153        if self.reforiginal_model_path:
154            pylogger.debug(
155                "Loading reference original model from %s.",
156                self.reforiginal_model_path,
157            )
158
159        # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which blocks loading custom classes.
160        # We explicitly set weights_only=False to allow loading full objects from our own checkpoints.
161        main_model = torch.load(
162            self.main_model_path, map_location="cpu", weights_only=False
163        )
164        refretrain_model = torch.load(
165            self.refretrain_model_path, map_location="cpu", weights_only=False
166        )
167        reforiginal_model = (
168            torch.load(
169                self.reforiginal_model_path, map_location="cpu", weights_only=False
170            )
171            if self.reforiginal_model_path
172            else None
173        )
174
175        pylogger.debug("Loaded main model type: %s", type(main_model))
176        pylogger.debug("Loaded reference retrain model type: %s", type(refretrain_model))
177        if reforiginal_model is not None:
178            pylogger.debug(
179                "Loaded reference original model type: %s", type(reforiginal_model)
180            )
181
182        self.evaluation_module = CULEvaluation(
183            main_model=main_model,
184            refretrain_model=refretrain_model,
185            reforiginal_model=reforiginal_model,
186            dd_eval_task_ids=self.dd_eval_tasks,
187            ag_eval_task_ids=self.ag_eval_tasks,
188        )
189        pylogger.debug(
190            "Evaluation module (clarena.utils.eval.CULEvaluation) instantiated!"
191        )

Instantiate the evaluation module object.

def instantiate_callbacks( self, metrics_cfg: omegaconf.listconfig.ListConfig, callbacks_cfg: omegaconf.listconfig.ListConfig) -> None:
193    def instantiate_callbacks(
194        self, metrics_cfg: ListConfig, callbacks_cfg: ListConfig
195    ) -> None:
196        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
197        pylogger.debug(
198            "Instantiating callbacks (lightning.Callback)...",
199        )
200
201        enabled_metric_cfgs = []
202        for callback in metrics_cfg:
203            if (
204                callback.get("_target_") == "clarena.metrics.CULAccuracyGain"
205                and (not self.reforiginal_model_path or not self.ag_eval_tasks)
206            ):
207                pylogger.info(
208                    "Skipping `clarena.metrics.CULAccuracyGain` because reference original evaluation is disabled."
209                )
210                continue
211            enabled_metric_cfgs.append(callback)
212
213        # instantiate metric callbacks
214        metric_callbacks = [
215            hydra.utils.instantiate(callback) for callback in enabled_metric_cfgs
216        ]
217
218        # instantiate other callbacks
219        other_callbacks = [
220            hydra.utils.instantiate(callback) for callback in callbacks_cfg
221        ]
222
223        # add metric callbacks to the list of callbacks
224        self.callbacks = metric_callbacks + other_callbacks
225        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

Instantiate the list of callbacks objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.

def instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, callbacks: list[lightning.pytorch.callbacks.callback.Callback]) -> None:
227    def instantiate_trainer(
228        self,
229        trainer_cfg: DictConfig,
230        callbacks: list[Callback],
231    ) -> None:
232        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""
233
234        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
235        self.trainer = hydra.utils.instantiate(
236            trainer_cfg,
237            callbacks=callbacks,
238        )
239        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

Instantiate the trainer object from trainer_cfg and callbacks.

def set_global_seed(self, global_seed: int) -> None:
241    def set_global_seed(self, global_seed: int) -> None:
242        r"""Set the `global_seed` for the entire evaluation."""
243        L.seed_everything(self.global_seed, workers=True)
244        pylogger.debug("Global seed is set as %d.", global_seed)

Set the global_seed for the entire evaluation.

def run(self) -> None:
246    def run(self) -> None:
247        r"""The main method to run the continual unlearning full evaluation."""
248
249        self.set_global_seed(self.global_seed)
250
251        # components
252        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
253        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
254        self.instantiate_evaluation_module()
255        self.instantiate_callbacks(
256            metrics_cfg=self.cfg.metrics,
257            callbacks_cfg=self.cfg.callbacks,
258        )
259        self.instantiate_trainer(
260            trainer_cfg=self.cfg.trainer,
261            callbacks=self.callbacks,
262        )  # trainer should be instantiated after callbacks
263
264        # setup tasks for dataset and evaluation module
265        self.cl_dataset.setup_tasks_eval(
266            eval_tasks=sorted(set(self.dd_eval_tasks + self.ag_eval_tasks))
267        )
268
269        # evaluation
270        self.trainer.test(
271            model=self.evaluation_module,
272            datamodule=self.cl_dataset,
273        )
274        # please note this will set up last task dataset twice, which is fine

The main method to run the continual unlearning full evaluation.