clarena.pipelines.mtl_eval

The submodule in pipelines for multi-task learning evaluation.

  1r"""The submodule in `pipelines` for multi-task learning evaluation."""
  2
  3__all__ = ["MTLEvaluation"]
  4
  5import logging
  6
  7import hydra
  8import lightning as L
  9import torch
 10from lightning import Callback, Trainer
 11from lightning.pytorch.loggers import Logger
 12from omegaconf import DictConfig, ListConfig
 13
 14from clarena.mtl_datasets import MTLDataset
 15
 16# always get logger for built-in logging in each module
 17pylogger = logging.getLogger(__name__)
 18
 19
 20class MTLEvaluation:
 21    r"""The base class for multi-task learning evaluation."""
 22
 23    def __init__(self, cfg: DictConfig) -> None:
 24        r"""
 25        **Args:**
 26        - **cfg** (`DictConfig`): the config dict for the multi-task learning evaluation.
 27        """
 28        self.cfg: DictConfig = cfg
 29        r"""The complete config dict."""
 30
 31        MTLEvaluation.sanity_check(self)
 32
 33        # required config fields
 34        self.eval_tasks: list[int] = (
 35            cfg.eval_tasks
 36            if isinstance(cfg.eval_tasks, ListConfig)
 37            else list(range(1, cfg.eval_tasks + 1))
 38        )
 39        r"""The list of task IDs to evaluate."""
 40        self.global_seed: int = cfg.global_seed
 41        r"""The global seed for the entire experiment."""
 42        self.output_dir: str = cfg.output_dir
 43        r"""The folder for storing the experiment results."""
 44        self.model_path: str = cfg.model_path
 45        r"""The file path of the model to evaluate."""
 46
 47        # components
 48        self.mtl_dataset: MTLDataset
 49        r"""MTL dataset object."""
 50        self.lightning_loggers: list[Logger]
 51        r"""Lightning logger objects."""
 52        self.callbacks: list[Callback]
 53        r"""Callback objects."""
 54        self.trainer: Trainer
 55        r"""Trainer object."""
 56
 57    def sanity_check(self) -> None:
 58        r"""Sanity check for config."""
 59
 60        # check required config fields
 61        required_config_fields = [
 62            "pipeline",
 63            "model_path",
 64            "eval_tasks",
 65            "global_seed",
 66            "trainer",
 67            "metrics",
 68            "callbacks",
 69            "output_dir",
 70            # "hydra" is excluded as it doesn't appear
 71            "misc",
 72        ]
 73        for field in required_config_fields:
 74            if not self.cfg.get(field):
 75                raise KeyError(
 76                    f"Field `{field}` is required in the experiment index config."
 77                )
 78
 79        # get dataset number of tasks
 80        if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL":
 81            cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset")
 82            if cl_dataset_cfg.get("num_tasks"):
 83                num_tasks = cl_dataset_cfg.get("num_tasks")
 84            elif cl_dataset_cfg.get("class_split"):
 85                num_tasks = len(cl_dataset_cfg.class_split)
 86            elif cl_dataset_cfg.get("datasets"):
 87                num_tasks = len(cl_dataset_cfg.datasets)
 88            else:
 89                raise KeyError(
 90                    "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
 91                )
 92        else:
 93            if self.cfg.mtl_dataset.get("num_tasks"):
 94                num_tasks = self.cfg.mtl_dataset.num_tasks
 95            else:
 96                raise KeyError(
 97                    "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config."
 98                )
 99
100        # check eval_tasks
101        eval_tasks = self.cfg.eval_tasks
102        if isinstance(eval_tasks, list):
103            if len(eval_tasks) < 1:
104                raise ValueError("`eval_tasks` must contain at least one task.")
105            if any(t < 1 or t > num_tasks for t in eval_tasks):
106                raise ValueError(
107                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
108                )
109        elif isinstance(eval_tasks, int):
110            if eval_tasks < 0 or eval_tasks > num_tasks:
111                raise ValueError(
112                    f"`eval_tasks` as integer must be between 0 and {num_tasks}."
113                )
114        else:
115            raise TypeError(
116                "`eval_tasks` must be either a list of integers or an integer."
117            )
118
119    def instantiate_mtl_dataset(
120        self,
121        mtl_dataset_cfg: DictConfig,
122    ) -> None:
123        r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`."""
124        pylogger.debug(
125            "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) ...",
126            mtl_dataset_cfg.get("_target_"),
127        )
128        self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg)
129        pylogger.debug(
130            "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated! ",
131            mtl_dataset_cfg.get("_target_"),
132        )
133
134    def instantiate_callbacks(
135        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
136    ) -> None:
137        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
138        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
139
140        # instantiate metric callbacks
141        metric_callbacks = [
142            hydra.utils.instantiate(callback) for callback in metrics_cfg
143        ]
144
145        # instantiate other callbacks
146        other_callbacks = [
147            hydra.utils.instantiate(callback) for callback in callbacks_cfg
148        ]
149
150        # add metric callbacks to the list of callbacks
151        self.callbacks = metric_callbacks + other_callbacks
152        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
153
154    def instantiate_trainer(
155        self,
156        trainer_cfg: DictConfig,
157        callbacks: list[Callback],
158    ) -> None:
159        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""
160
161        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
162        self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
163        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
164
165    def set_global_seed(self, global_seed: int) -> None:
166        r"""Set the `global_seed` for the entire evaluation."""
167        L.seed_everything(self.global_seed, workers=True)
168        pylogger.debug("Global seed is set as %d.", global_seed)
169
170    def run(self) -> None:
171        r"""The main method to run the multi-task learning evaluation."""
172
173        self.set_global_seed(self.global_seed)
174
175        # load the model from file
176        model = torch.load(self.model_path)
177
178        self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset)
179        self.instantiate_callbacks(
180            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
181        )
182        self.instantiate_trainer(
183            trainer_cfg=self.cfg.trainer,
184            callbacks=self.callbacks,
185        )  # trainer should be instantiated after callbacks
186
187        # setup tasks for dataset
188        self.mtl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)
189
190        # evaluation skipping training and validation
191        self.trainer.test(
192            model=model,
193            datamodule=self.mtl_dataset,
194        )
class MTLEvaluation:
 21class MTLEvaluation:
 22    r"""The base class for multi-task learning evaluation."""
 23
 24    def __init__(self, cfg: DictConfig) -> None:
 25        r"""
 26        **Args:**
 27        - **cfg** (`DictConfig`): the config dict for the multi-task learning evaluation.
 28        """
 29        self.cfg: DictConfig = cfg
 30        r"""The complete config dict."""
 31
 32        MTLEvaluation.sanity_check(self)
 33
 34        # required config fields
 35        self.eval_tasks: list[int] = (
 36            cfg.eval_tasks
 37            if isinstance(cfg.eval_tasks, ListConfig)
 38            else list(range(1, cfg.eval_tasks + 1))
 39        )
 40        r"""The list of task IDs to evaluate."""
 41        self.global_seed: int = cfg.global_seed
 42        r"""The global seed for the entire experiment."""
 43        self.output_dir: str = cfg.output_dir
 44        r"""The folder for storing the experiment results."""
 45        self.model_path: str = cfg.model_path
 46        r"""The file path of the model to evaluate."""
 47
 48        # components
 49        self.mtl_dataset: MTLDataset
 50        r"""MTL dataset object."""
 51        self.lightning_loggers: list[Logger]
 52        r"""Lightning logger objects."""
 53        self.callbacks: list[Callback]
 54        r"""Callback objects."""
 55        self.trainer: Trainer
 56        r"""Trainer object."""
 57
 58    def sanity_check(self) -> None:
 59        r"""Sanity check for config."""
 60
 61        # check required config fields
 62        required_config_fields = [
 63            "pipeline",
 64            "model_path",
 65            "eval_tasks",
 66            "global_seed",
 67            "trainer",
 68            "metrics",
 69            "callbacks",
 70            "output_dir",
 71            # "hydra" is excluded as it doesn't appear
 72            "misc",
 73        ]
 74        for field in required_config_fields:
 75            if not self.cfg.get(field):
 76                raise KeyError(
 77                    f"Field `{field}` is required in the experiment index config."
 78                )
 79
 80        # get dataset number of tasks
 81        if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL":
 82            cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset")
 83            if cl_dataset_cfg.get("num_tasks"):
 84                num_tasks = cl_dataset_cfg.get("num_tasks")
 85            elif cl_dataset_cfg.get("class_split"):
 86                num_tasks = len(cl_dataset_cfg.class_split)
 87            elif cl_dataset_cfg.get("datasets"):
 88                num_tasks = len(cl_dataset_cfg.datasets)
 89            else:
 90                raise KeyError(
 91                    "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
 92                )
 93        else:
 94            if self.cfg.mtl_dataset.get("num_tasks"):
 95                num_tasks = self.cfg.mtl_dataset.num_tasks
 96            else:
 97                raise KeyError(
 98                    "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config."
 99                )
100
101        # check eval_tasks
102        eval_tasks = self.cfg.eval_tasks
103        if isinstance(eval_tasks, list):
104            if len(eval_tasks) < 1:
105                raise ValueError("`eval_tasks` must contain at least one task.")
106            if any(t < 1 or t > num_tasks for t in eval_tasks):
107                raise ValueError(
108                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
109                )
110        elif isinstance(eval_tasks, int):
111            if eval_tasks < 0 or eval_tasks > num_tasks:
112                raise ValueError(
113                    f"`eval_tasks` as integer must be between 0 and {num_tasks}."
114                )
115        else:
116            raise TypeError(
117                "`eval_tasks` must be either a list of integers or an integer."
118            )
119
120    def instantiate_mtl_dataset(
121        self,
122        mtl_dataset_cfg: DictConfig,
123    ) -> None:
124        r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`."""
125        pylogger.debug(
126            "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) ...",
127            mtl_dataset_cfg.get("_target_"),
128        )
129        self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg)
130        pylogger.debug(
131            "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated! ",
132            mtl_dataset_cfg.get("_target_"),
133        )
134
135    def instantiate_callbacks(
136        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
137    ) -> None:
 138        r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
139        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
140
141        # instantiate metric callbacks
142        metric_callbacks = [
143            hydra.utils.instantiate(callback) for callback in metrics_cfg
144        ]
145
146        # instantiate other callbacks
147        other_callbacks = [
148            hydra.utils.instantiate(callback) for callback in callbacks_cfg
149        ]
150
151        # add metric callbacks to the list of callbacks
152        self.callbacks = metric_callbacks + other_callbacks
153        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
154
155    def instantiate_trainer(
156        self,
157        trainer_cfg: DictConfig,
158        callbacks: list[Callback],
159    ) -> None:
160        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""
161
162        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
163        self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
164        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
165
166    def set_global_seed(self, global_seed: int) -> None:
167        r"""Set the `global_seed` for the entire evaluation."""
168        L.seed_everything(self.global_seed, workers=True)
169        pylogger.debug("Global seed is set as %d.", global_seed)
170
171    def run(self) -> None:
172        r"""The main method to run the multi-task learning evaluation."""
173
174        self.set_global_seed(self.global_seed)
175
176        # load the model from file
177        model = torch.load(self.model_path)
178
179        self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset)
180        self.instantiate_callbacks(
181            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
182        )
183        self.instantiate_trainer(
184            trainer_cfg=self.cfg.trainer,
185            callbacks=self.callbacks,
186        )  # trainer should be instantiated after callbacks
187
188        # setup tasks for dataset
189        self.mtl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)
190
191        # evaluation skipping training and validation
192        self.trainer.test(
193            model=model,
194            datamodule=self.mtl_dataset,
195        )

The base class for multi-task learning evaluation.

MTLEvaluation(cfg: omegaconf.dictconfig.DictConfig)
24    def __init__(self, cfg: DictConfig) -> None:
25        r"""
26        **Args:**
27        - **cfg** (`DictConfig`): the config dict for the multi-task learning evaluation.
28        """
29        self.cfg: DictConfig = cfg
30        r"""The complete config dict."""
31
32        MTLEvaluation.sanity_check(self)
33
34        # required config fields
35        self.eval_tasks: list[int] = (
36            cfg.eval_tasks
37            if isinstance(cfg.eval_tasks, ListConfig)
38            else list(range(1, cfg.eval_tasks + 1))
39        )
40        r"""The list of task IDs to evaluate."""
41        self.global_seed: int = cfg.global_seed
42        r"""The global seed for the entire experiment."""
43        self.output_dir: str = cfg.output_dir
44        r"""The folder for storing the experiment results."""
45        self.model_path: str = cfg.model_path
46        r"""The file path of the model to evaluate."""
47
48        # components
49        self.mtl_dataset: MTLDataset
50        r"""MTL dataset object."""
51        self.lightning_loggers: list[Logger]
52        r"""Lightning logger objects."""
53        self.callbacks: list[Callback]
54        r"""Callback objects."""
55        self.trainer: Trainer
56        r"""Trainer object."""

Args:

  • cfg (DictConfig): the config dict for the multi-task learning evaluation.
cfg: omegaconf.dictconfig.DictConfig

The complete config dict.

eval_tasks: list[int]

The list of task IDs to evaluate.

global_seed: int

The global seed for the entire experiment.

output_dir: str

The folder for storing the experiment results.

model_path: str

The file path of the model to evaluate.

MTL dataset object.

lightning_loggers: list[lightning.pytorch.loggers.logger.Logger]

Lightning logger objects.

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

Callback objects.

trainer: lightning.pytorch.trainer.trainer.Trainer

Trainer object.

def sanity_check(self) -> None:
 58    def sanity_check(self) -> None:
 59        r"""Sanity check for config."""
 60
 61        # check required config fields
 62        required_config_fields = [
 63            "pipeline",
 64            "model_path",
 65            "eval_tasks",
 66            "global_seed",
 67            "trainer",
 68            "metrics",
 69            "callbacks",
 70            "output_dir",
 71            # "hydra" is excluded as it doesn't appear
 72            "misc",
 73        ]
 74        for field in required_config_fields:
 75            if not self.cfg.get(field):
 76                raise KeyError(
 77                    f"Field `{field}` is required in the experiment index config."
 78                )
 79
 80        # get dataset number of tasks
 81        if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL":
 82            cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset")
 83            if cl_dataset_cfg.get("num_tasks"):
 84                num_tasks = cl_dataset_cfg.get("num_tasks")
 85            elif cl_dataset_cfg.get("class_split"):
 86                num_tasks = len(cl_dataset_cfg.class_split)
 87            elif cl_dataset_cfg.get("datasets"):
 88                num_tasks = len(cl_dataset_cfg.datasets)
 89            else:
 90                raise KeyError(
 91                    "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
 92                )
 93        else:
 94            if self.cfg.mtl_dataset.get("num_tasks"):
 95                num_tasks = self.cfg.mtl_dataset.num_tasks
 96            else:
 97                raise KeyError(
 98                    "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config."
 99                )
100
101        # check eval_tasks
102        eval_tasks = self.cfg.eval_tasks
103        if isinstance(eval_tasks, list):
104            if len(eval_tasks) < 1:
105                raise ValueError("`eval_tasks` must contain at least one task.")
106            if any(t < 1 or t > num_tasks for t in eval_tasks):
107                raise ValueError(
108                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
109                )
110        elif isinstance(eval_tasks, int):
111            if eval_tasks < 0 or eval_tasks > num_tasks:
112                raise ValueError(
113                    f"`eval_tasks` as integer must be between 0 and {num_tasks}."
114                )
115        else:
116            raise TypeError(
117                "`eval_tasks` must be either a list of integers or an integer."
118            )

Sanity check for config.

def instantiate_mtl_dataset(self, mtl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
120    def instantiate_mtl_dataset(
121        self,
122        mtl_dataset_cfg: DictConfig,
123    ) -> None:
124        r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`."""
125        pylogger.debug(
126            "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) ...",
127            mtl_dataset_cfg.get("_target_"),
128        )
129        self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg)
130        pylogger.debug(
131            "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated! ",
132            mtl_dataset_cfg.get("_target_"),
133        )

Instantiate the MTL dataset object from mtl_dataset_cfg.

def instantiate_callbacks( self, metrics_cfg: omegaconf.dictconfig.DictConfig, callbacks_cfg: omegaconf.dictconfig.DictConfig) -> None:
135    def instantiate_callbacks(
136        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
137    ) -> None:
 138        r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
139        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
140
141        # instantiate metric callbacks
142        metric_callbacks = [
143            hydra.utils.instantiate(callback) for callback in metrics_cfg
144        ]
145
146        # instantiate other callbacks
147        other_callbacks = [
148            hydra.utils.instantiate(callback) for callback in callbacks_cfg
149        ]
150
151        # add metric callbacks to the list of callbacks
152        self.callbacks = metric_callbacks + other_callbacks
153        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

Instantiate the list of callback objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.

def instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, callbacks: list[lightning.pytorch.callbacks.callback.Callback]) -> None:
155    def instantiate_trainer(
156        self,
157        trainer_cfg: DictConfig,
158        callbacks: list[Callback],
159    ) -> None:
160        r"""Instantiate the trainer object from `trainer_cfg` and `callbacks`."""
161
162        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
163        self.trainer = hydra.utils.instantiate(trainer_cfg, callbacks=callbacks)
164        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

Instantiate the trainer object from trainer_cfg and callbacks.

def set_global_seed(self, global_seed: int) -> None:
166    def set_global_seed(self, global_seed: int) -> None:
167        r"""Set the `global_seed` for the entire evaluation."""
168        L.seed_everything(self.global_seed, workers=True)
169        pylogger.debug("Global seed is set as %d.", global_seed)

Set the global_seed for the entire evaluation.

def run(self) -> None:
171    def run(self) -> None:
172        r"""The main method to run the multi-task learning evaluation."""
173
174        self.set_global_seed(self.global_seed)
175
176        # load the model from file
177        model = torch.load(self.model_path)
178
179        self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset)
180        self.instantiate_callbacks(
181            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
182        )
183        self.instantiate_trainer(
184            trainer_cfg=self.cfg.trainer,
185            callbacks=self.callbacks,
186        )  # trainer should be instantiated after callbacks
187
188        # setup tasks for dataset
189        self.mtl_dataset.setup_tasks_eval(eval_tasks=self.eval_tasks)
190
191        # evaluation skipping training and validation
192        self.trainer.test(
193            model=model,
194            datamodule=self.mtl_dataset,
195        )

The main method to run the multi-task learning evaluation.