clarena.pipelines.stl_eval

The submodule in pipelines for single-task learning evaluation.

  1r"""The submodule in `pipelines` for single-task learning evaluation."""
  2
  3__all__ = ["STLEvaluation"]
  4
  5import logging
  6
  7import hydra
  8import lightning as L
  9import torch
 10from lightning import Callback, Trainer
 11from lightning.pytorch.loggers import Logger
 12from omegaconf import DictConfig
 13
 14from clarena.pipelines import STLExperiment
 15from clarena.stl_datasets import STLDataset
 16
 17# always get logger for built-in logging in each module
 18pylogger = logging.getLogger(__name__)
 19
 20
 21class STLEvaluation:
 22    r"""The base class for single-task learning evaluation."""
 23
 24    def __init__(self, cfg: DictConfig) -> None:
 25        r"""
 26        **Args:**
 27        - **cfg** (`DictConfig`): the config dict for the single-task learning evaluation.
 28        """
 29        self.cfg: DictConfig = cfg
 30        r"""The complete config dict."""
 31
 32        STLEvaluation.sanity_check(self)
 33
 34        # required config fields
 35        self.global_seed: int = cfg.global_seed
 36        r"""The global seed for the entire experiment."""
 37        self.output_dir: str = cfg.output_dir
 38        r"""The folder for storing the experiment results."""
 39        self.model_path: str = cfg.model_path
 40        r"""The file path of the model to evaluate."""
 41
 42        # components
 43        self.stl_dataset: STLDataset
 44        r"""STL dataset object."""
 45        self.lightning_loggers: list[Logger]
 46        r"""Lightning logger objects."""
 47        self.callbacks: list[Callback]
 48        r"""Callback objects."""
 49        self.trainer: Trainer
 50        r"""Trainer object."""
 51
 52    def sanity_check(self) -> None:
 53        r"""Sanity check for config."""
 54
 55        # check required config fields
 56        required_config_fields = [
 57            "pipeline",
 58            "model_path",
 59            "global_seed",
 60            "stl_dataset",
 61            "trainer",
 62            "metrics",
 63            "callbacks",
 64            "output_dir",
 65            # "hydra" is excluded as it doesn't appear
 66            "misc",
 67        ]
 68        for field in required_config_fields:
 69            if not self.cfg.get(field):
 70                raise KeyError(
 71                    f"Field `{field}` is required in the experiment index config."
 72                )
 73
 74    def instantiate_stl_dataset(
 75        self,
 76        stl_dataset_cfg: DictConfig,
 77    ) -> None:
 78        r"""Instantiate the STL dataset object from `stl_dataset_cfg`."""
 79        pylogger.debug(
 80            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
 81            stl_dataset_cfg.get("_target_"),
 82        )
 83        self.stl_dataset = hydra.utils.instantiate(
 84            stl_dataset_cfg,
 85        )
 86        pylogger.debug(
 87            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
 88            stl_dataset_cfg.get("_target_"),
 89        )
 90
 91    def instantiate_callbacks(
 92        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
 93    ) -> None:
 94        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
 95        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
 96
 97        # instantiate metric callbacks
 98        metric_callbacks = [
 99            hydra.utils.instantiate(callback) for callback in metrics_cfg
100        ]
101
102        # instantiate other callbacks
103        other_callbacks = [
104            hydra.utils.instantiate(callback) for callback in callbacks_cfg
105        ]
106
107        # add metric callbacks to the list of callbacks
108        self.callbacks = metric_callbacks + other_callbacks
109        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
110
111    def instantiate_trainer(
112        self,
113        trainer_cfg: DictConfig,
114        callbacks: list[Callback],
115    ) -> None:
116        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
117
118        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
119        self.trainer = hydra.utils.instantiate(
120            trainer_cfg, callbacks=callbacks
121        )
122        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
123
124    def set_global_seed(self, global_seed: int) -> None:
125        r"""Set the `global_seed` for the entire evaluation."""
126        L.seed_everything(self.global_seed, workers=True)
127        pylogger.debug("Global seed is set as %d.", global_seed)
128
129    def run(self) -> None:
130        r"""The main method to run the single-task learning experiment."""
131        self.set_global_seed(self.global_seed)
132
133        # load the model from file
134        model = torch.load(self.model_path)
135
136        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
137        self.instantiate_callbacks(
138            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
139        )
140        self.instantiate_trainer(
141            trainer_cfg=self.cfg.trainer,
142            callbacks=self.callbacks,
143        )  # trainer should be instantiated after callbacks
144
145        # setup task for dataset
146        self.stl_dataset.setup_task()
147
148        # evaluation skipping training and validation
149        self.trainer.test(
150            model=model,
151            datamodule=self.stl_dataset,
152        )
class STLEvaluation:
 22class STLEvaluation:
 23    r"""The base class for single-task learning evaluation."""
 24
 25    def __init__(self, cfg: DictConfig) -> None:
 26        r"""
 27        **Args:**
 28        - **cfg** (`DictConfig`): the config dict for the single-task learning evaluation.
 29        """
 30        self.cfg: DictConfig = cfg
 31        r"""The complete config dict."""
 32
 33        STLEvaluation.sanity_check(self)
 34
 35        # required config fields
 36        self.global_seed: int = cfg.global_seed
 37        r"""The global seed for the entire experiment."""
 38        self.output_dir: str = cfg.output_dir
 39        r"""The folder for storing the experiment results."""
 40        self.model_path: str = cfg.model_path
 41        r"""The file path of the model to evaluate."""
 42
 43        # components
 44        self.stl_dataset: STLDataset
 45        r"""STL dataset object."""
 46        self.lightning_loggers: list[Logger]
 47        r"""Lightning logger objects."""
 48        self.callbacks: list[Callback]
 49        r"""Callback objects."""
 50        self.trainer: Trainer
 51        r"""Trainer object."""
 52
 53    def sanity_check(self) -> None:
 54        r"""Sanity check for config."""
 55
 56        # check required config fields
 57        required_config_fields = [
 58            "pipeline",
 59            "model_path",
 60            "global_seed",
 61            "stl_dataset",
 62            "trainer",
 63            "metrics",
 64            "callbacks",
 65            "output_dir",
 66            # "hydra" is excluded as it doesn't appear
 67            "misc",
 68        ]
 69        for field in required_config_fields:
 70            if not self.cfg.get(field):
 71                raise KeyError(
 72                    f"Field `{field}` is required in the experiment index config."
 73                )
 74
 75    def instantiate_stl_dataset(
 76        self,
 77        stl_dataset_cfg: DictConfig,
 78    ) -> None:
 79        r"""Instantiate the STL dataset object from `stl_dataset_cfg`."""
 80        pylogger.debug(
 81            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
 82            stl_dataset_cfg.get("_target_"),
 83        )
 84        self.stl_dataset = hydra.utils.instantiate(
 85            stl_dataset_cfg,
 86        )
 87        pylogger.debug(
 88            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
 89            stl_dataset_cfg.get("_target_"),
 90        )
 91
 92    def instantiate_callbacks(
 93        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
 94    ) -> None:
 95        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
 96        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
 97
 98        # instantiate metric callbacks
 99        metric_callbacks = [
100            hydra.utils.instantiate(callback) for callback in metrics_cfg
101        ]
102
103        # instantiate other callbacks
104        other_callbacks = [
105            hydra.utils.instantiate(callback) for callback in callbacks_cfg
106        ]
107
108        # add metric callbacks to the list of callbacks
109        self.callbacks = metric_callbacks + other_callbacks
110        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
111
112    def instantiate_trainer(
113        self,
114        trainer_cfg: DictConfig,
115        callbacks: list[Callback],
116    ) -> None:
117        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
118
119        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
120        self.trainer = hydra.utils.instantiate(
121            trainer_cfg, callbacks=callbacks
122        )
123        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
124
125    def set_global_seed(self, global_seed: int) -> None:
126        r"""Set the `global_seed` for the entire evaluation."""
127        L.seed_everything(self.global_seed, workers=True)
128        pylogger.debug("Global seed is set as %d.", global_seed)
129
130    def run(self) -> None:
131        r"""The main method to run the single-task learning experiment."""
132        self.set_global_seed(self.global_seed)
133
134        # load the model from file
135        model = torch.load(self.model_path)
136
137        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
138        self.instantiate_callbacks(
139            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
140        )
141        self.instantiate_trainer(
142            trainer_cfg=self.cfg.trainer,
143            callbacks=self.callbacks,
144        )  # trainer should be instantiated after callbacks
145
146        # setup task for dataset
147        self.stl_dataset.setup_task()
148
149        # evaluation skipping training and validation
150        self.trainer.test(
151            model=model,
152            datamodule=self.stl_dataset,
153        )

The base class for single-task learning evaluation.

STLEvaluation(cfg: omegaconf.dictconfig.DictConfig)
25    def __init__(self, cfg: DictConfig) -> None:
26        r"""
27        **Args:**
28        - **cfg** (`DictConfig`): the config dict for the single-task learning evaluation.
29        """
30        self.cfg: DictConfig = cfg
31        r"""The complete config dict."""
32
33        STLEvaluation.sanity_check(self)
34
35        # required config fields
36        self.global_seed: int = cfg.global_seed
37        r"""The global seed for the entire experiment."""
38        self.output_dir: str = cfg.output_dir
39        r"""The folder for storing the experiment results."""
40        self.model_path: str = cfg.model_path
41        r"""The file path of the model to evaluate."""
42
43        # components
44        self.stl_dataset: STLDataset
45        r"""STL dataset object."""
46        self.lightning_loggers: list[Logger]
47        r"""Lightning logger objects."""
48        self.callbacks: list[Callback]
49        r"""Callback objects."""
50        self.trainer: Trainer
51        r"""Trainer object."""

Args:

  • cfg (DictConfig): the config dict for the single-task learning evaluation.
cfg: omegaconf.dictconfig.DictConfig

The complete config dict.

global_seed: int

The global seed for the entire experiment.

output_dir: str

The folder for storing the experiment results.

model_path: str

The file path of the model to evaluate.

stl_dataset: clarena.stl_datasets.STLDataset

STL dataset object.

lightning_loggers: list[lightning.pytorch.loggers.logger.Logger]

Lightning logger objects.

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

Callback objects.

trainer: lightning.pytorch.trainer.trainer.Trainer

Trainer object.

def sanity_check(self) -> None:
53    def sanity_check(self) -> None:
54        r"""Sanity check for config."""
55
56        # check required config fields
57        required_config_fields = [
58            "pipeline",
59            "model_path",
60            "global_seed",
61            "stl_dataset",
62            "trainer",
63            "metrics",
64            "callbacks",
65            "output_dir",
66            # "hydra" is excluded as it doesn't appear
67            "misc",
68        ]
69        for field in required_config_fields:
70            if not self.cfg.get(field):
71                raise KeyError(
72                    f"Field `{field}` is required in the experiment index config."
73                )

Sanity check for config.

def instantiate_stl_dataset(self, stl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
75    def instantiate_stl_dataset(
76        self,
77        stl_dataset_cfg: DictConfig,
78    ) -> None:
79        r"""Instantiate the STL dataset object from `stl_dataset_cfg`."""
80        pylogger.debug(
81            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
82            stl_dataset_cfg.get("_target_"),
83        )
84        self.stl_dataset = hydra.utils.instantiate(
85            stl_dataset_cfg,
86        )
87        pylogger.debug(
88            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
89            stl_dataset_cfg.get("_target_"),
90        )

Instantiate the STL dataset object from stl_dataset_cfg.

def instantiate_callbacks( self, metrics_cfg: omegaconf.dictconfig.DictConfig, callbacks_cfg: omegaconf.dictconfig.DictConfig) -> None:
 92    def instantiate_callbacks(
 93        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
 94    ) -> None:
 95        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
 96        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
 97
 98        # instantiate metric callbacks
 99        metric_callbacks = [
100            hydra.utils.instantiate(callback) for callback in metrics_cfg
101        ]
102
103        # instantiate other callbacks
104        other_callbacks = [
105            hydra.utils.instantiate(callback) for callback in callbacks_cfg
106        ]
107
108        # add metric callbacks to the list of callbacks
109        self.callbacks = metric_callbacks + other_callbacks
110        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

Instantiate the list of callback objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.

def instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, callbacks: list[lightning.pytorch.callbacks.callback.Callback]) -> None:
112    def instantiate_trainer(
113        self,
114        trainer_cfg: DictConfig,
115        callbacks: list[Callback],
116    ) -> None:
117        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
118
119        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
120        self.trainer = hydra.utils.instantiate(
121            trainer_cfg, callbacks=callbacks
122        )
123        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

Instantiate the trainer object from trainer_cfg and callbacks.

def set_global_seed(self, global_seed: int) -> None:
125    def set_global_seed(self, global_seed: int) -> None:
126        r"""Set the `global_seed` for the entire evaluation."""
127        L.seed_everything(self.global_seed, workers=True)
128        pylogger.debug("Global seed is set as %d.", global_seed)

Set the global_seed for the entire evaluation.

def run(self) -> None:
130    def run(self) -> None:
131        r"""The main method to run the single-task learning experiment."""
132        self.set_global_seed(self.global_seed)
133
134        # load the model from file
135        model = torch.load(self.model_path)
136
137        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
138        self.instantiate_callbacks(
139            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
140        )
141        self.instantiate_trainer(
142            trainer_cfg=self.cfg.trainer,
143            callbacks=self.callbacks,
144        )  # trainer should be instantiated after callbacks
145
146        # setup task for dataset
147        self.stl_dataset.setup_task()
148
149        # evaluation skipping training and validation
150        self.trainer.test(
151            model=model,
152            datamodule=self.stl_dataset,
153        )

The main method to run the single-task learning evaluation.