clarena.pipelines.cl_main_eval

The submodule in pipelines for continual learning main evaluation.

  1r"""
  2The submodule in `pipelines` for continual learning main evaluation.
  3"""
  4
  5__all__ = ["CLMainEvaluation"]
  6
  7import logging
  8
  9import hydra
 10import lightning as L
 11import torch
 12from lightning import Callback, Trainer
 13from omegaconf import DictConfig, ListConfig
 14
 15from clarena.cl_datasets import CLDataset
 16
 17# always get logger for built-in logging in each module
 18pylogger = logging.getLogger(__name__)
 19
 20
 21class CLMainEvaluation:
 22    r"""The base class for continual learning main evaluation."""
 23
 24    def __init__(self, cfg: DictConfig) -> None:
 25        r"""
 26        **Args:**
 27        - **cfg** (`DictConfig`): the config dict for the continual learning main evaluation.
 28        """
 29        self.cfg: DictConfig = cfg
 30        r"""The complete config dict."""
 31
 32        CLMainEvaluation.sanity_check(self)
 33
 34        # required config fields
 35        self.main_model_path: str = cfg.main_model_path
 36        r"""The file path of the model to evaluate."""
 37        self.cl_paradigm: str = cfg.clmlp
 38        r"""The continual learning paradigm."""
 39        self.eval_tasks: list[int] = (
 40            cfg.eval_tasks
 41            if isinstance(cfg.eval_tasks, ListConfig)
 42            else list(range(1, cfg.eval_tasks + 1))
 43        )
 44        r"""The list of task IDs to evaluate."""
 45        self.global_seed: int = cfg.global_seed
 46        r"""The global seed for the entire experiment."""
 47        self.output_dir: str = cfg.output_dir
 48        r"""The folder for storing the experiment results."""
 49
 50        # components
 51        self.cl_dataset: CLDataset
 52        r"""CL dataset object."""
 53        self.callbacks: list[Callback]
 54        r"""Callback objects."""
 55        self.trainer: Trainer
 56        r"""Trainer object."""
 57
 58    def sanity_check(self) -> None:
 59        r"""Sanity check for config."""
 60
 61        # check required config fields
 62        required_config_fields = [
 63            "pipeline",
 64            "main_model_path",
 65            "eval_tasks",
 66            "cl_paradigm",
 67            "global_seed",
 68            "cl_dataset",
 69            "trainer",
 70            "metrics",
 71            "callbacks",
 72            "output_dir",
 73            # "hydra" is excluded as it doesn't appear
 74            "misc",
 75        ]
 76        for field in required_config_fields:
 77            if not self.cfg.get(field):
 78                raise KeyError(
 79                    f"Field `{field}` is required in the experiment index config."
 80                )
 81
 82        # check cl_paradigm
 83        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
 84            raise ValueError(
 85                f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
 86            )
 87
 88        # check eval_tasks
 89        if self.cfg.cl_dataset.get("num_tasks"):
 90            num_tasks = self.cfg.cl_dataset.get("num_tasks")
 91        elif self.cfg.cl_dataset.get("class_split"):
 92            num_tasks = len(self.cfg.cl_dataset.class_split)
 93        elif self.cfg.cl_dataset.get("datasets"):
 94            num_tasks = len(self.cfg.cl_dataset.datasets)
 95        else:
 96            raise KeyError(
 97                "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
 98            )
 99
100        eval_tasks = self.cfg.eval_tasks
101        if isinstance(eval_tasks, ListConfig):
102            if len(eval_tasks) < 1:
103                raise ValueError("`eval_tasks` must contain at least one task.")
104            if any(t < 1 or t > num_tasks for t in eval_tasks):
105                raise ValueError(
106                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
107                )
108        elif isinstance(eval_tasks, int):
109            if eval_tasks < 0 or eval_tasks > num_tasks:
110                raise ValueError(
111                    f"`eval_tasks` as integer must be between 0 and {num_tasks}."
112                )
113        else:
114            raise TypeError(
115                "`eval_tasks` must be either a list of integers or an integer."
116            )
117
118    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
119        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
120        pylogger.debug(
121            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
122            cl_dataset_cfg.get("_target_"),
123        )
124        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
125        pylogger.debug(
126            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
127            cl_dataset_cfg.get("_target_"),
128        )
129
130    def instantiate_callbacks(
131        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
132    ) -> None:
133        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
134        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
135
136        # instantiate metric callbacks
137        metric_callbacks: list[Callback] = [
138            hydra.utils.instantiate(callback) for callback in metrics_cfg
139        ]
140
141        # instantiate other callbacks
142        other_callbacks: list[Callback] = [
143            hydra.utils.instantiate(callback) for callback in callbacks_cfg
144        ]
145
146        # add metric callbacks to the list of callbacks
147        self.callbacks = metric_callbacks + other_callbacks
148        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
149
150    def instantiate_trainer(
151        self,
152        trainer_cfg: DictConfig,
153        callbacks: list[Callback],
154    ) -> None:
155        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""
156
157        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
158        self.trainer = hydra.utils.instantiate(
159            trainer_cfg,
160            callbacks=callbacks,
161        )
162        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
163
164    def set_global_seed(self, global_seed: int) -> None:
165        r"""Set the `global_seed` for the entire evaluation."""
166        L.seed_everything(self.global_seed, workers=True)
167        pylogger.debug("Global seed is set as %d.", global_seed)
168
169    def run(self) -> None:
170        r"""The main method to run the continual learning main evaluation."""
171
172        self.set_global_seed(self.global_seed)
173
174        # load the model from file
175        model = torch.load(self.main_model_path)
176
177        # components
178        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
179        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
180        self.instantiate_callbacks(
181            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
182        )
183        self.instantiate_trainer(
184            trainer_cfg=self.cfg.trainer,
185            callbacks=self.callbacks,
186        )  # trainer should be instantiated after callbacks
187
188        # setup tasks for dataset
189        self.cl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)
190
191        # evaluation skipping training and validation
192        self.trainer.test(
193            model=model,
194            datamodule=self.cl_dataset,
195        )
class CLMainEvaluation:
 22class CLMainEvaluation:
 23    r"""The base class for continual learning main evaluation."""
 24
 25    def __init__(self, cfg: DictConfig) -> None:
 26        r"""
 27        **Args:**
 28        - **cfg** (`DictConfig`): the config dict for the continual learning main evaluation.
 29        """
 30        self.cfg: DictConfig = cfg
 31        r"""The complete config dict."""
 32
 33        CLMainEvaluation.sanity_check(self)
 34
 35        # required config fields
 36        self.main_model_path: str = cfg.main_model_path
 37        r"""The file path of the model to evaluate."""
 38        self.cl_paradigm: str = cfg.clmlp
 39        r"""The continual learning paradigm."""
 40        self.eval_tasks: list[int] = (
 41            cfg.eval_tasks
 42            if isinstance(cfg.eval_tasks, ListConfig)
 43            else list(range(1, cfg.eval_tasks + 1))
 44        )
 45        r"""The list of task IDs to evaluate."""
 46        self.global_seed: int = cfg.global_seed
 47        r"""The global seed for the entire experiment."""
 48        self.output_dir: str = cfg.output_dir
 49        r"""The folder for storing the experiment results."""
 50
 51        # components
 52        self.cl_dataset: CLDataset
 53        r"""CL dataset object."""
 54        self.callbacks: list[Callback]
 55        r"""Callback objects."""
 56        self.trainer: Trainer
 57        r"""Trainer object."""
 58
 59    def sanity_check(self) -> None:
 60        r"""Sanity check for config."""
 61
 62        # check required config fields
 63        required_config_fields = [
 64            "pipeline",
 65            "main_model_path",
 66            "eval_tasks",
 67            "cl_paradigm",
 68            "global_seed",
 69            "cl_dataset",
 70            "trainer",
 71            "metrics",
 72            "callbacks",
 73            "output_dir",
 74            # "hydra" is excluded as it doesn't appear
 75            "misc",
 76        ]
 77        for field in required_config_fields:
 78            if not self.cfg.get(field):
 79                raise KeyError(
 80                    f"Field `{field}` is required in the experiment index config."
 81                )
 82
 83        # check cl_paradigm
 84        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
 85            raise ValueError(
 86                f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
 87            )
 88
 89        # check eval_tasks
 90        if self.cfg.cl_dataset.get("num_tasks"):
 91            num_tasks = self.cfg.cl_dataset.get("num_tasks")
 92        elif self.cfg.cl_dataset.get("class_split"):
 93            num_tasks = len(self.cfg.cl_dataset.class_split)
 94        elif self.cfg.cl_dataset.get("datasets"):
 95            num_tasks = len(self.cfg.cl_dataset.datasets)
 96        else:
 97            raise KeyError(
 98                "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
 99            )
100
101        eval_tasks = self.cfg.eval_tasks
102        if isinstance(eval_tasks, ListConfig):
103            if len(eval_tasks) < 1:
104                raise ValueError("`eval_tasks` must contain at least one task.")
105            if any(t < 1 or t > num_tasks for t in eval_tasks):
106                raise ValueError(
107                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
108                )
109        elif isinstance(eval_tasks, int):
110            if eval_tasks < 0 or eval_tasks > num_tasks:
111                raise ValueError(
112                    f"`eval_tasks` as integer must be between 0 and {num_tasks}."
113                )
114        else:
115            raise TypeError(
116                "`eval_tasks` must be either a list of integers or an integer."
117            )
118
119    def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
120        r"""Instantiate the CL dataset object from `cl_dataset_cfg`."""
121        pylogger.debug(
122            "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...",
123            cl_dataset_cfg.get("_target_"),
124        )
125        self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
126        pylogger.debug(
127            "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!",
128            cl_dataset_cfg.get("_target_"),
129        )
130
131    def instantiate_callbacks(
132        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
133    ) -> None:
134        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
135        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
136
137        # instantiate metric callbacks
138        metric_callbacks: list[Callback] = [
139            hydra.utils.instantiate(callback) for callback in metrics_cfg
140        ]
141
142        # instantiate other callbacks
143        other_callbacks: list[Callback] = [
144            hydra.utils.instantiate(callback) for callback in callbacks_cfg
145        ]
146
147        # add metric callbacks to the list of callbacks
148        self.callbacks = metric_callbacks + other_callbacks
149        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
150
151    def instantiate_trainer(
152        self,
153        trainer_cfg: DictConfig,
154        callbacks: list[Callback],
155    ) -> None:
156        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""
157
158        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
159        self.trainer = hydra.utils.instantiate(
160            trainer_cfg,
161            callbacks=callbacks,
162        )
163        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
164
165    def set_global_seed(self, global_seed: int) -> None:
166        r"""Set the `global_seed` for the entire evaluation."""
167        L.seed_everything(self.global_seed, workers=True)
168        pylogger.debug("Global seed is set as %d.", global_seed)
169
170    def run(self) -> None:
171        r"""The main method to run the continual learning main evaluation."""
172
173        self.set_global_seed(self.global_seed)
174
175        # load the model from file
176        model = torch.load(self.main_model_path)
177
178        # components
179        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
180        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
181        self.instantiate_callbacks(
182            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
183        )
184        self.instantiate_trainer(
185            trainer_cfg=self.cfg.trainer,
186            callbacks=self.callbacks,
187        )  # trainer should be instantiated after callbacks
188
189        # setup tasks for dataset
190        self.cl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)
191
192        # evaluation skipping training and validation
193        self.trainer.test(
194            model=model,
195            datamodule=self.cl_dataset,
196        )

The base class for continual learning main evaluation.

CLMainEvaluation(cfg: omegaconf.dictconfig.DictConfig)
def __init__(self, cfg: DictConfig) -> None:
    r"""
    **Args:**
    - **cfg** (`DictConfig`): the config dict for the continual learning main evaluation.
    """
    self.cfg: DictConfig = cfg
    r"""The complete config dict."""

    # call the base-class check explicitly (not `self.sanity_check()`), so this
    # validation runs even if a subclass overrides `sanity_check()`
    CLMainEvaluation.sanity_check(self)

    # required config fields
    self.main_model_path: str = cfg.main_model_path
    r"""The file path of the model to evaluate."""
    self.cl_paradigm: str = cfg.cl_paradigm
    r"""The continual learning paradigm, one of 'TIL', 'CIL' or 'DIL'."""
    self.eval_tasks: list[int] = (
        list(cfg.eval_tasks)
        if isinstance(cfg.eval_tasks, ListConfig)
        else list(range(1, cfg.eval_tasks + 1))
    )
    r"""The list of task IDs to evaluate. An integer `n` in the config means tasks `1, ..., n`."""
    self.global_seed: int = cfg.global_seed
    r"""The global seed for the entire experiment."""
    self.output_dir: str = cfg.output_dir
    r"""The folder for storing the experiment results."""

    # components (assigned by the `instantiate_*()` methods during `run()`)
    self.cl_dataset: CLDataset
    r"""CL dataset object."""
    self.callbacks: list[Callback]
    r"""Callback objects."""
    self.trainer: Trainer
    r"""Trainer object."""

Args:

  • cfg (DictConfig): the config dict for the continual learning main evaluation.
cfg: omegaconf.dictconfig.DictConfig

The complete config dict.

main_model_path: str

The file path of the model to evaluate.

cl_paradigm: str

The continual learning paradigm, one of 'TIL', 'CIL' or 'DIL'.

eval_tasks: list[int]

The list of task IDs to evaluate.

global_seed: int

The global seed for the entire experiment.

output_dir: str

The folder for storing the experiment results.

cl_dataset: clarena.cl_datasets.CLDataset

CL dataset object.

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

Callback objects.

trainer: lightning.pytorch.trainer.trainer.Trainer

Trainer object.

def sanity_check(self) -> None:
 59    def sanity_check(self) -> None:
 60        r"""Sanity check for config."""
 61
 62        # check required config fields
 63        required_config_fields = [
 64            "pipeline",
 65            "main_model_path",
 66            "eval_tasks",
 67            "cl_paradigm",
 68            "global_seed",
 69            "cl_dataset",
 70            "trainer",
 71            "metrics",
 72            "callbacks",
 73            "output_dir",
 74            # "hydra" is excluded as it doesn't appear
 75            "misc",
 76        ]
 77        for field in required_config_fields:
 78            if not self.cfg.get(field):
 79                raise KeyError(
 80                    f"Field `{field}` is required in the experiment index config."
 81                )
 82
 83        # check cl_paradigm
 84        if self.cfg.cl_paradigm not in ["TIL", "CIL", "DIL"]:
 85            raise ValueError(
 86                f"Field `cl_paradigm` should be either 'TIL', 'CIL' or 'DIL' but got {self.cfg.cl_paradigm}!"
 87            )
 88
 89        # check eval_tasks
 90        if self.cfg.cl_dataset.get("num_tasks"):
 91            num_tasks = self.cfg.cl_dataset.get("num_tasks")
 92        elif self.cfg.cl_dataset.get("class_split"):
 93            num_tasks = len(self.cfg.cl_dataset.class_split)
 94        elif self.cfg.cl_dataset.get("datasets"):
 95            num_tasks = len(self.cfg.cl_dataset.datasets)
 96        else:
 97            raise KeyError(
 98                "`num_tasks` is required in cl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
 99            )
100
101        eval_tasks = self.cfg.eval_tasks
102        if isinstance(eval_tasks, ListConfig):
103            if len(eval_tasks) < 1:
104                raise ValueError("`eval_tasks` must contain at least one task.")
105            if any(t < 1 or t > num_tasks for t in eval_tasks):
106                raise ValueError(
107                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
108                )
109        elif isinstance(eval_tasks, int):
110            if eval_tasks < 0 or eval_tasks > num_tasks:
111                raise ValueError(
112                    f"`eval_tasks` as integer must be between 0 and {num_tasks}."
113                )
114        else:
115            raise TypeError(
116                "`eval_tasks` must be either a list of integers or an integer."
117            )

Sanity check for config.

def instantiate_cl_dataset(self, cl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
def instantiate_cl_dataset(self, cl_dataset_cfg: DictConfig) -> None:
    r"""Instantiate the CL dataset object from `cl_dataset_cfg`.

    **Args:**
    - **cl_dataset_cfg** (`DictConfig`): the `cl_dataset` config, instantiated via Hydra from its `_target_`.
    """
    target = cl_dataset_cfg.get("_target_")
    pylogger.debug(
        "Instantiating CL dataset <%s> (clarena.cl_datasets.CLDataset)...", target
    )
    self.cl_dataset = hydra.utils.instantiate(cl_dataset_cfg)
    pylogger.debug(
        "CL dataset <%s> (clarena.cl_datasets.CLDataset) instantiated!", target
    )

Instantiate the CL dataset object from cl_dataset_cfg.

def instantiate_callbacks( self, metrics_cfg: omegaconf.dictconfig.DictConfig, callbacks_cfg: omegaconf.dictconfig.DictConfig) -> None:
def instantiate_callbacks(
    self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
) -> None:
    r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`.

    `metrics_cfg` is a list of metric callback configs and `callbacks_cfg` is a list of
    the other (non-metric) callback configs. The resulting `self.callbacks` holds the
    metric callbacks first, followed by the other callbacks.

    **Args:**
    - **metrics_cfg** (`DictConfig`): configs of the metric callbacks.
    - **callbacks_cfg** (`DictConfig`): configs of the other callbacks.
    """
    pylogger.debug("Instantiating callbacks (lightning.Callback)...")

    instantiated: list[Callback] = []
    # metric callbacks go first, then every other callback
    for group_cfg in (metrics_cfg, callbacks_cfg):
        instantiated.extend(hydra.utils.instantiate(cb_cfg) for cb_cfg in group_cfg)
    self.callbacks = instantiated

    pylogger.debug("Callbacks (lightning.Callback) instantiated!")

Instantiate the list of callback objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.

def instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, callbacks: list[lightning.pytorch.callbacks.callback.Callback]) -> None:
def instantiate_trainer(
    self,
    trainer_cfg: DictConfig,
    callbacks: list[Callback],
) -> None:
    r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`.

    **Args:**
    - **trainer_cfg** (`DictConfig`): the trainer config, instantiated via Hydra.
    - **callbacks** (`list[Callback]`): callback objects passed to the trainer as its `callbacks` argument.
    """
    pylogger.debug("Instantiating trainer (lightning.Trainer)...")
    trainer: Trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
    self.trainer = trainer
    pylogger.debug("Trainer (lightning.Trainer) instantiated!")

Instantiate the trainer object from trainer_cfg and callbacks.

def set_global_seed(self, global_seed: int) -> None:
def set_global_seed(self, global_seed: int) -> None:
    r"""Set the `global_seed` for the entire evaluation.

    **Args:**
    - **global_seed** (`int`): the seed passed to `lightning.seed_everything` (with `workers=True`, which also configures dataloader worker seeding).
    """
    # use the argument (previously `self.global_seed` was used, silently ignoring it)
    L.seed_everything(global_seed, workers=True)
    pylogger.debug("Global seed is set as %d.", global_seed)

Set the global_seed for the entire evaluation.

def run(self) -> None:
def run(self) -> None:
    r"""The main method to run the continual learning main evaluation.

    Seeds everything, loads the model from `main_model_path`, builds the CL dataset,
    callbacks and trainer from the config, sets up the evaluation tasks and runs
    `Trainer.test()`.
    """

    self.set_global_seed(self.global_seed)

    # load the whole pickled model object from file. Since PyTorch 2.6 `torch.load`
    # defaults to `weights_only=True`, which refuses to unpickle arbitrary objects
    # such as a full model, so full unpickling is requested explicitly.
    # NOTE(review): full unpickling can execute arbitrary code; only load model
    # files from trusted sources.
    model = torch.load(self.main_model_path, weights_only=False)

    # components
    self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
    self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
    self.instantiate_callbacks(
        metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
    )
    self.instantiate_trainer(
        trainer_cfg=self.cfg.trainer,
        callbacks=self.callbacks,
    )  # trainer should be instantiated after callbacks

    # setup tasks for dataset
    self.cl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)

    # evaluation skipping training and validation
    self.trainer.test(
        model=model,
        datamodule=self.cl_dataset,
    )

The main method to run the continual learning main evaluation.