clarena.pipelines.stl_expr

The submodule in pipelines for single-task learning experiment.

  1r"""
  2The submodule in `pipelines` for single-task learning experiment.
  3
  4"""
  5
  6__all__ = ["STLExperiment"]
  7
  8import logging
  9from typing import Any
 10
 11import hydra
 12import lightning as L
 13from lightning import Callback, Trainer
 14from lightning.pytorch.loggers import Logger
 15from omegaconf import DictConfig
 16from torch.optim import Optimizer
 17from torch.optim.lr_scheduler import LRScheduler
 18
 19from clarena.backbones import Backbone
 20from clarena.heads import HeadSTL
 21from clarena.stl_algorithms import STLAlgorithm
 22from clarena.stl_datasets import STLDataset
 23from clarena.utils.cfg import select_hyperparameters_from_config
 24
 25# always get logger for built-in logging in each module
 26pylogger = logging.getLogger(__name__)
 27
 28
 29class STLExperiment:
 30    r"""The base class for single-task learning experiment."""
 31
 32    def __init__(self, cfg: DictConfig) -> None:
 33        r"""
 34        **Args:**
 35        - **cfg** (`DictConfig`): the complete config dict for the single-task learning experiment.
 36        """
 37        self.cfg: DictConfig = cfg
 38        r"""The complete config dict."""
 39
 40        STLExperiment.sanity_check(self)
 41
 42        # required config fields
 43        self.eval: bool = cfg.eval
 44        r"""Whether to include evaluation phase."""
 45        self.global_seed: int = cfg.global_seed
 46        r"""The global seed for the entire experiment."""
 47
 48        # components
 49        self.stl_dataset: STLDataset
 50        r"""STL dataset object."""
 51        self.backbone: Backbone
 52        r"""Backbone network object."""
 53        self.head: HeadSTL
 54        r"""STL output heads object."""
 55        self.model: STLAlgorithm
 56        r"""STL model object."""
 57        self.optimizer: Optimizer
 58        r"""Optimizer object."""
 59        self.lr_scheduler: LRScheduler | None
 60        r"""Learning rate scheduler object."""
 61        self.lightning_loggers: list[Logger]
 62        r"""The list of initialized lightning loggers objects."""
 63        self.callbacks: list[Callback]
 64        r"""The list of initialized callbacks objects."""
 65        self.trainer: Trainer
 66        r"""Trainer object."""
 67
 68    def sanity_check(self) -> None:
 69        r"""Sanity check for config."""
 70
 71        # check required config fields
 72        required_config_fields = [
 73            "pipeline",
 74            "expr_name",
 75            "global_seed",
 76            "stl_dataset",
 77            "stl_algorithm",
 78            "backbone",
 79            "optimizer",
 80            "lr_scheduler",
 81            "trainer",
 82            "metrics",
 83            "lightning_loggers",
 84            "callbacks",
 85            "output_dir",
 86            # "hydra" is excluded as it doesn't appear
 87            "misc",
 88        ]
 89        for field in required_config_fields:
 90            if not self.cfg.get(field):
 91                raise KeyError(
 92                    f"Field `{field}` is required in the experiment index config."
 93                )
 94
 95    def instantiate_stl_dataset(
 96        self,
 97        stl_dataset_cfg: DictConfig,
 98    ) -> None:
 99        r"""Instantiate the STL dataset object from `stl_dataset_cfg`."""
100        pylogger.debug(
101            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
102            stl_dataset_cfg.get("_target_"),
103        )
104        self.stl_dataset = hydra.utils.instantiate(
105            stl_dataset_cfg,
106        )
107        pylogger.debug(
108            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
109            stl_dataset_cfg.get("_target_"),
110        )
111
112    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
113        r"""Instantiate the MTL backbone network object from `backbone_cfg`."""
114        pylogger.debug(
115            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
116            backbone_cfg.get("_target_"),
117        )
118        self.backbone = hydra.utils.instantiate(backbone_cfg)
119        pylogger.debug(
120            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
121            backbone_cfg.get("_target_"),
122        )
123
124    def instantiate_head(self, input_dim: int) -> None:
125        r"""Instantiate the STL output head object.
126
127        **Args:**
128        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
129        """
130        pylogger.debug(
131            "Instantiating STL head...",
132        )
133        self.head = HeadSTL(input_dim=input_dim)
134        pylogger.debug("STL head instantiated! ")
135
136    def instantiate_stl_algorithm(
137        self,
138        stl_algorithm_cfg: DictConfig,
139        backbone: Backbone,
140        head: HeadSTL,
141        non_algorithmic_hparams: dict[str, Any],
142    ) -> None:
143        r"""Instantiate the stl_algorithm object from `stl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
144        pylogger.debug(
145            "STL algorithm is set as <%s>. Instantiating <%s> (clarena.stl_algorithms.STLAlgorithm)...",
146            stl_algorithm_cfg.get("_target_"),
147            stl_algorithm_cfg.get("_target_"),
148        )
149        self.model = hydra.utils.instantiate(
150            stl_algorithm_cfg,
151            backbone=backbone,
152            head=head,
153            non_algorithmic_hparams=non_algorithmic_hparams,
154        )
155        pylogger.debug(
156            "<%s> (clarena.stl_algorithms.STLAlgorithm) instantiated!",
157            stl_algorithm_cfg.get("_target_"),
158        )
159
160    def instantiate_optimizer(
161        self,
162        optimizer_cfg: DictConfig,
163    ) -> None:
164        r"""Instantiate the optimizer object from `optimizer_cfg`."""
165
166        # partially instantiate optimizer as the 'params' argument is from Lightning Modules cannot be passed for now.
167        pylogger.debug(
168            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
169            optimizer_cfg.get("_target_"),
170        )
171        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
172        pylogger.debug(
173            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
174            optimizer_cfg.get("_target_"),
175        )
176
177    def instantiate_lr_scheduler(
178        self,
179        lr_scheduler_cfg: DictConfig,
180    ) -> None:
181        r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`."""
182
183        # partially instantiate learning rate scheduler as the 'params' argument is from Lightning Modules cannot be passed for now.
184        pylogger.debug(
185            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
186            lr_scheduler_cfg.get("_target_"),
187        )
188        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
189        pylogger.debug(
190            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
191            lr_scheduler_cfg.get("_target_"),
192        )
193
194    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
195        r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`."""
196        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
197        self.lightning_loggers = [
198            hydra.utils.instantiate(lightning_logger)
199            for lightning_logger in lightning_loggers_cfg.values()
200        ]
201        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")
202
203    def instantiate_callbacks(
204        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
205    ) -> None:
206        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
207        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
208
209        # instantiate metric callbacks
210        metric_callbacks = [
211            hydra.utils.instantiate(callback) for callback in metrics_cfg
212        ]
213
214        # instantiate other callbacks
215        other_callbacks = [
216            hydra.utils.instantiate(callback) for callback in callbacks_cfg
217        ]
218
219        # add metric callbacks to the list of callbacks
220        self.callbacks = metric_callbacks + other_callbacks
221        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
222
223    def instantiate_trainer(
224        self,
225        trainer_cfg: DictConfig,
226        lightning_loggers: list[Logger],
227        callbacks: list[Callback],
228    ) -> None:
229        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
230
231        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
232        self.trainer = hydra.utils.instantiate(
233            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
234        )
235        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
236
237    def set_global_seed(self, global_seed: int) -> None:
238        r"""Set the `global_seed` for the entire experiment."""
239        L.seed_everything(self.global_seed, workers=True)
240        pylogger.debug("Global seed is set as %d.", global_seed)
241
242    def run(self) -> None:
243        r"""The main method to run the single-task learning experiment."""
244        self.set_global_seed(self.global_seed)
245
246        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
247        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
248        self.instantiate_head(input_dim=self.cfg.backbone.output_dim)
249        self.instantiate_stl_algorithm(
250            stl_algorithm_cfg=self.cfg.stl_algorithm,
251            backbone=self.backbone,
252            head=self.head,
253            non_algorithmic_hparams=select_hyperparameters_from_config(
254                cfg=self.cfg, type=self.cfg.pipeline
255            ),
256        )  # stl_algorithm should be instantiated after backbone and heads
257        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
258        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
259        self.instantiate_lightning_loggers(
260            lightning_loggers_cfg=self.cfg.lightning_loggers
261        )
262        self.instantiate_callbacks(
263            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
264        )
265        self.instantiate_trainer(
266            trainer_cfg=self.cfg.trainer,
267            lightning_loggers=self.lightning_loggers,
268            callbacks=self.callbacks,
269        )  # trainer should be instantiated after loggers and callbacks
270
271        # setup task for dataset and model
272        self.stl_dataset.setup_task()
273        t = self.stl_dataset.get_class_map()
274        print(t)
275        self.model.setup_task(
276            num_classes=len(self.stl_dataset.get_class_map()),
277            optimizer=self.optimizer,
278            lr_scheduler=self.lr_scheduler,
279        )
280
281        # fit the model on the STL dataset
282        self.trainer.fit(
283            model=self.model,
284            datamodule=self.stl_dataset,
285        )
286
287        # evaluation after training and validation
288        self.trainer.test(
289            model=self.model,
290            datamodule=self.stl_dataset,
291        )
class STLExperiment:
 30class STLExperiment:
 31    r"""The base class for single-task learning experiment."""
 32
 33    def __init__(self, cfg: DictConfig) -> None:
 34        r"""
 35        **Args:**
 36        - **cfg** (`DictConfig`): the complete config dict for the single-task learning experiment.
 37        """
 38        self.cfg: DictConfig = cfg
 39        r"""The complete config dict."""
 40
 41        STLExperiment.sanity_check(self)
 42
 43        # required config fields
 44        self.eval: bool = cfg.eval
 45        r"""Whether to include evaluation phase."""
 46        self.global_seed: int = cfg.global_seed
 47        r"""The global seed for the entire experiment."""
 48
 49        # components
 50        self.stl_dataset: STLDataset
 51        r"""STL dataset object."""
 52        self.backbone: Backbone
 53        r"""Backbone network object."""
 54        self.head: HeadSTL
 55        r"""STL output heads object."""
 56        self.model: STLAlgorithm
 57        r"""STL model object."""
 58        self.optimizer: Optimizer
 59        r"""Optimizer object."""
 60        self.lr_scheduler: LRScheduler | None
 61        r"""Learning rate scheduler object."""
 62        self.lightning_loggers: list[Logger]
 63        r"""The list of initialized lightning loggers objects."""
 64        self.callbacks: list[Callback]
 65        r"""The list of initialized callbacks objects."""
 66        self.trainer: Trainer
 67        r"""Trainer object."""
 68
 69    def sanity_check(self) -> None:
 70        r"""Sanity check for config."""
 71
 72        # check required config fields
 73        required_config_fields = [
 74            "pipeline",
 75            "expr_name",
 76            "global_seed",
 77            "stl_dataset",
 78            "stl_algorithm",
 79            "backbone",
 80            "optimizer",
 81            "lr_scheduler",
 82            "trainer",
 83            "metrics",
 84            "lightning_loggers",
 85            "callbacks",
 86            "output_dir",
 87            # "hydra" is excluded as it doesn't appear
 88            "misc",
 89        ]
 90        for field in required_config_fields:
 91            if not self.cfg.get(field):
 92                raise KeyError(
 93                    f"Field `{field}` is required in the experiment index config."
 94                )
 95
 96    def instantiate_stl_dataset(
 97        self,
 98        stl_dataset_cfg: DictConfig,
 99    ) -> None:
100        r"""Instantiate the STL dataset object from `stl_dataset_cfg`."""
101        pylogger.debug(
102            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
103            stl_dataset_cfg.get("_target_"),
104        )
105        self.stl_dataset = hydra.utils.instantiate(
106            stl_dataset_cfg,
107        )
108        pylogger.debug(
109            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
110            stl_dataset_cfg.get("_target_"),
111        )
112
113    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
114        r"""Instantiate the MTL backbone network object from `backbone_cfg`."""
115        pylogger.debug(
116            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
117            backbone_cfg.get("_target_"),
118        )
119        self.backbone = hydra.utils.instantiate(backbone_cfg)
120        pylogger.debug(
121            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
122            backbone_cfg.get("_target_"),
123        )
124
125    def instantiate_head(self, input_dim: int) -> None:
126        r"""Instantiate the STL output head object.
127
128        **Args:**
129        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
130        """
131        pylogger.debug(
132            "Instantiating STL head...",
133        )
134        self.head = HeadSTL(input_dim=input_dim)
135        pylogger.debug("STL head instantiated! ")
136
137    def instantiate_stl_algorithm(
138        self,
139        stl_algorithm_cfg: DictConfig,
140        backbone: Backbone,
141        head: HeadSTL,
142        non_algorithmic_hparams: dict[str, Any],
143    ) -> None:
144        r"""Instantiate the stl_algorithm object from `stl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
145        pylogger.debug(
146            "STL algorithm is set as <%s>. Instantiating <%s> (clarena.stl_algorithms.STLAlgorithm)...",
147            stl_algorithm_cfg.get("_target_"),
148            stl_algorithm_cfg.get("_target_"),
149        )
150        self.model = hydra.utils.instantiate(
151            stl_algorithm_cfg,
152            backbone=backbone,
153            head=head,
154            non_algorithmic_hparams=non_algorithmic_hparams,
155        )
156        pylogger.debug(
157            "<%s> (clarena.stl_algorithms.STLAlgorithm) instantiated!",
158            stl_algorithm_cfg.get("_target_"),
159        )
160
161    def instantiate_optimizer(
162        self,
163        optimizer_cfg: DictConfig,
164    ) -> None:
165        r"""Instantiate the optimizer object from `optimizer_cfg`."""
166
167        # partially instantiate optimizer as the 'params' argument is from Lightning Modules cannot be passed for now.
168        pylogger.debug(
169            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
170            optimizer_cfg.get("_target_"),
171        )
172        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
173        pylogger.debug(
174            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
175            optimizer_cfg.get("_target_"),
176        )
177
178    def instantiate_lr_scheduler(
179        self,
180        lr_scheduler_cfg: DictConfig,
181    ) -> None:
182        r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`."""
183
184        # partially instantiate learning rate scheduler as the 'params' argument is from Lightning Modules cannot be passed for now.
185        pylogger.debug(
186            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
187            lr_scheduler_cfg.get("_target_"),
188        )
189        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
190        pylogger.debug(
191            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
192            lr_scheduler_cfg.get("_target_"),
193        )
194
195    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
196        r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`."""
197        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
198        self.lightning_loggers = [
199            hydra.utils.instantiate(lightning_logger)
200            for lightning_logger in lightning_loggers_cfg.values()
201        ]
202        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")
203
204    def instantiate_callbacks(
205        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
206    ) -> None:
207        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
208        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
209
210        # instantiate metric callbacks
211        metric_callbacks = [
212            hydra.utils.instantiate(callback) for callback in metrics_cfg
213        ]
214
215        # instantiate other callbacks
216        other_callbacks = [
217            hydra.utils.instantiate(callback) for callback in callbacks_cfg
218        ]
219
220        # add metric callbacks to the list of callbacks
221        self.callbacks = metric_callbacks + other_callbacks
222        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
223
224    def instantiate_trainer(
225        self,
226        trainer_cfg: DictConfig,
227        lightning_loggers: list[Logger],
228        callbacks: list[Callback],
229    ) -> None:
230        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
231
232        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
233        self.trainer = hydra.utils.instantiate(
234            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
235        )
236        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
237
238    def set_global_seed(self, global_seed: int) -> None:
239        r"""Set the `global_seed` for the entire experiment."""
240        L.seed_everything(self.global_seed, workers=True)
241        pylogger.debug("Global seed is set as %d.", global_seed)
242
243    def run(self) -> None:
244        r"""The main method to run the single-task learning experiment."""
245        self.set_global_seed(self.global_seed)
246
247        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
248        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
249        self.instantiate_head(input_dim=self.cfg.backbone.output_dim)
250        self.instantiate_stl_algorithm(
251            stl_algorithm_cfg=self.cfg.stl_algorithm,
252            backbone=self.backbone,
253            head=self.head,
254            non_algorithmic_hparams=select_hyperparameters_from_config(
255                cfg=self.cfg, type=self.cfg.pipeline
256            ),
257        )  # stl_algorithm should be instantiated after backbone and heads
258        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
259        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
260        self.instantiate_lightning_loggers(
261            lightning_loggers_cfg=self.cfg.lightning_loggers
262        )
263        self.instantiate_callbacks(
264            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
265        )
266        self.instantiate_trainer(
267            trainer_cfg=self.cfg.trainer,
268            lightning_loggers=self.lightning_loggers,
269            callbacks=self.callbacks,
270        )  # trainer should be instantiated after loggers and callbacks
271
272        # setup task for dataset and model
273        self.stl_dataset.setup_task()
274        t = self.stl_dataset.get_class_map()
275        print(t)
276        self.model.setup_task(
277            num_classes=len(self.stl_dataset.get_class_map()),
278            optimizer=self.optimizer,
279            lr_scheduler=self.lr_scheduler,
280        )
281
282        # fit the model on the STL dataset
283        self.trainer.fit(
284            model=self.model,
285            datamodule=self.stl_dataset,
286        )
287
288        # evaluation after training and validation
289        self.trainer.test(
290            model=self.model,
291            datamodule=self.stl_dataset,
292        )

The base class for single-task learning experiment.

STLExperiment(cfg: omegaconf.dictconfig.DictConfig)
33    def __init__(self, cfg: DictConfig) -> None:
34        r"""
35        **Args:**
36        - **cfg** (`DictConfig`): the complete config dict for the single-task learning experiment.
37        """
38        self.cfg: DictConfig = cfg
39        r"""The complete config dict."""
40
41        STLExperiment.sanity_check(self)
42
43        # required config fields
44        self.eval: bool = cfg.eval
45        r"""Whether to include evaluation phase."""
46        self.global_seed: int = cfg.global_seed
47        r"""The global seed for the entire experiment."""
48
49        # components
50        self.stl_dataset: STLDataset
51        r"""STL dataset object."""
52        self.backbone: Backbone
53        r"""Backbone network object."""
54        self.head: HeadSTL
55        r"""STL output heads object."""
56        self.model: STLAlgorithm
57        r"""STL model object."""
58        self.optimizer: Optimizer
59        r"""Optimizer object."""
60        self.lr_scheduler: LRScheduler | None
61        r"""Learning rate scheduler object."""
62        self.lightning_loggers: list[Logger]
63        r"""The list of initialized lightning loggers objects."""
64        self.callbacks: list[Callback]
65        r"""The list of initialized callbacks objects."""
66        self.trainer: Trainer
67        r"""Trainer object."""

Args:

  • cfg (DictConfig): the complete config dict for the single-task learning experiment.
cfg: omegaconf.dictconfig.DictConfig

The complete config dict.

eval: bool

Whether to include evaluation phase.

global_seed: int

The global seed for the entire experiment.

STL dataset object.

Backbone network object.

STL output heads object.

STL model object.

optimizer: torch.optim.optimizer.Optimizer

Optimizer object.

lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None

Learning rate scheduler object.

lightning_loggers: list[lightning.pytorch.loggers.logger.Logger]

The list of initialized lightning loggers objects.

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

The list of initialized callbacks objects.

trainer: lightning.pytorch.trainer.trainer.Trainer

Trainer object.

def sanity_check(self) -> None:
69    def sanity_check(self) -> None:
70        r"""Sanity check for config."""
71
72        # check required config fields
73        required_config_fields = [
74            "pipeline",
75            "expr_name",
76            "global_seed",
77            "stl_dataset",
78            "stl_algorithm",
79            "backbone",
80            "optimizer",
81            "lr_scheduler",
82            "trainer",
83            "metrics",
84            "lightning_loggers",
85            "callbacks",
86            "output_dir",
87            # "hydra" is excluded as it doesn't appear
88            "misc",
89        ]
90        for field in required_config_fields:
91            if not self.cfg.get(field):
92                raise KeyError(
93                    f"Field `{field}` is required in the experiment index config."
94                )

Sanity check for config.

def instantiate_stl_dataset(self, stl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
 96    def instantiate_stl_dataset(
 97        self,
 98        stl_dataset_cfg: DictConfig,
 99    ) -> None:
100        r"""Instantiate the STL dataset object from `stl_dataset_cfg`."""
101        pylogger.debug(
102            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
103            stl_dataset_cfg.get("_target_"),
104        )
105        self.stl_dataset = hydra.utils.instantiate(
106            stl_dataset_cfg,
107        )
108        pylogger.debug(
109            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
110            stl_dataset_cfg.get("_target_"),
111        )

Instantiate the STL dataset object from stl_dataset_cfg.

def instantiate_backbone(self, backbone_cfg: omegaconf.dictconfig.DictConfig) -> None:
113    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
114        r"""Instantiate the MTL backbone network object from `backbone_cfg`."""
115        pylogger.debug(
116            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
117            backbone_cfg.get("_target_"),
118        )
119        self.backbone = hydra.utils.instantiate(backbone_cfg)
120        pylogger.debug(
121            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
122            backbone_cfg.get("_target_"),
123        )

Instantiate the MTL backbone network object from backbone_cfg.

def instantiate_head(self, input_dim: int) -> None:
125    def instantiate_head(self, input_dim: int) -> None:
126        r"""Instantiate the STL output head object.
127
128        **Args:**
129        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
130        """
131        pylogger.debug(
132            "Instantiating STL head...",
133        )
134        self.head = HeadSTL(input_dim=input_dim)
135        pylogger.debug("STL head instantiated! ")

Instantiate the STL output head object.

Args:

  • input_dim (int): the input dimension of the head. Must be equal to the output_dim of the connected backbone.
def instantiate_stl_algorithm( self, stl_algorithm_cfg: omegaconf.dictconfig.DictConfig, backbone: clarena.backbones.Backbone, head: clarena.heads.HeadSTL, non_algorithmic_hparams: dict[str, typing.Any]) -> None:
137    def instantiate_stl_algorithm(
138        self,
139        stl_algorithm_cfg: DictConfig,
140        backbone: Backbone,
141        head: HeadSTL,
142        non_algorithmic_hparams: dict[str, Any],
143    ) -> None:
144        r"""Instantiate the stl_algorithm object from `stl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
145        pylogger.debug(
146            "STL algorithm is set as <%s>. Instantiating <%s> (clarena.stl_algorithms.STLAlgorithm)...",
147            stl_algorithm_cfg.get("_target_"),
148            stl_algorithm_cfg.get("_target_"),
149        )
150        self.model = hydra.utils.instantiate(
151            stl_algorithm_cfg,
152            backbone=backbone,
153            head=head,
154            non_algorithmic_hparams=non_algorithmic_hparams,
155        )
156        pylogger.debug(
157            "<%s> (clarena.stl_algorithms.STLAlgorithm) instantiated!",
158            stl_algorithm_cfg.get("_target_"),
159        )

Instantiate the stl_algorithm object from stl_algorithm_cfg, backbone, heads and non_algorithmic_hparams.

def instantiate_optimizer(self, optimizer_cfg: omegaconf.dictconfig.DictConfig) -> None:
161    def instantiate_optimizer(
162        self,
163        optimizer_cfg: DictConfig,
164    ) -> None:
165        r"""Instantiate the optimizer object from `optimizer_cfg`."""
166
167        # partially instantiate optimizer as the 'params' argument is from Lightning Modules cannot be passed for now.
168        pylogger.debug(
169            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
170            optimizer_cfg.get("_target_"),
171        )
172        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
173        pylogger.debug(
174            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
175            optimizer_cfg.get("_target_"),
176        )

Instantiate the optimizer object from optimizer_cfg.

def instantiate_lr_scheduler(self, lr_scheduler_cfg: omegaconf.dictconfig.DictConfig) -> None:
178    def instantiate_lr_scheduler(
179        self,
180        lr_scheduler_cfg: DictConfig,
181    ) -> None:
182        r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`."""
183
184        # partially instantiate learning rate scheduler as the 'params' argument is from Lightning Modules cannot be passed for now.
185        pylogger.debug(
186            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
187            lr_scheduler_cfg.get("_target_"),
188        )
189        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
190        pylogger.debug(
191            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
192            lr_scheduler_cfg.get("_target_"),
193        )

Instantiate the learning rate scheduler object from lr_scheduler_cfg.

def instantiate_lightning_loggers(self, lightning_loggers_cfg: omegaconf.dictconfig.DictConfig) -> None:
195    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
196        r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`."""
197        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
198        self.lightning_loggers = [
199            hydra.utils.instantiate(lightning_logger)
200            for lightning_logger in lightning_loggers_cfg.values()
201        ]
202        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")

Instantiate the list of lightning loggers objects from lightning_loggers_cfg.

def instantiate_callbacks( self, metrics_cfg: omegaconf.dictconfig.DictConfig, callbacks_cfg: omegaconf.dictconfig.DictConfig) -> None:
204    def instantiate_callbacks(
205        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
206    ) -> None:
207        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
208        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
209
210        # instantiate metric callbacks
211        metric_callbacks = [
212            hydra.utils.instantiate(callback) for callback in metrics_cfg
213        ]
214
215        # instantiate other callbacks
216        other_callbacks = [
217            hydra.utils.instantiate(callback) for callback in callbacks_cfg
218        ]
219
220        # add metric callbacks to the list of callbacks
221        self.callbacks = metric_callbacks + other_callbacks
222        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

Instantiate the list of callbacks objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.

def instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, lightning_loggers: list[lightning.pytorch.loggers.logger.Logger], callbacks: list[lightning.pytorch.callbacks.callback.Callback]) -> None:
224    def instantiate_trainer(
225        self,
226        trainer_cfg: DictConfig,
227        lightning_loggers: list[Logger],
228        callbacks: list[Callback],
229    ) -> None:
230        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
231
232        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
233        self.trainer = hydra.utils.instantiate(
234            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
235        )
236        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

Instantiate the trainer object from trainer_cfg, lightning_loggers, and callbacks.

def set_global_seed(self, global_seed: int) -> None:
238    def set_global_seed(self, global_seed: int) -> None:
239        r"""Set the `global_seed` for the entire experiment."""
240        L.seed_everything(self.global_seed, workers=True)
241        pylogger.debug("Global seed is set as %d.", global_seed)

Set the global_seed for the entire experiment.

def run(self) -> None:
243    def run(self) -> None:
244        r"""The main method to run the single-task learning experiment."""
245        self.set_global_seed(self.global_seed)
246
247        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
248        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
249        self.instantiate_head(input_dim=self.cfg.backbone.output_dim)
250        self.instantiate_stl_algorithm(
251            stl_algorithm_cfg=self.cfg.stl_algorithm,
252            backbone=self.backbone,
253            head=self.head,
254            non_algorithmic_hparams=select_hyperparameters_from_config(
255                cfg=self.cfg, type=self.cfg.pipeline
256            ),
257        )  # stl_algorithm should be instantiated after backbone and heads
258        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
259        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
260        self.instantiate_lightning_loggers(
261            lightning_loggers_cfg=self.cfg.lightning_loggers
262        )
263        self.instantiate_callbacks(
264            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
265        )
266        self.instantiate_trainer(
267            trainer_cfg=self.cfg.trainer,
268            lightning_loggers=self.lightning_loggers,
269            callbacks=self.callbacks,
270        )  # trainer should be instantiated after loggers and callbacks
271
272        # setup task for dataset and model
273        self.stl_dataset.setup_task()
274        t = self.stl_dataset.get_class_map()
275        print(t)
276        self.model.setup_task(
277            num_classes=len(self.stl_dataset.get_class_map()),
278            optimizer=self.optimizer,
279            lr_scheduler=self.lr_scheduler,
280        )
281
282        # fit the model on the STL dataset
283        self.trainer.fit(
284            model=self.model,
285            datamodule=self.stl_dataset,
286        )
287
288        # evaluation after training and validation
289        self.trainer.test(
290            model=self.model,
291            datamodule=self.stl_dataset,
292        )

The main method to run the single-task learning experiment.