clarena.pipelines.stl_expr
The submodule in pipelines for single-task learning experiments.
r"""
The submodule in `pipelines` for single-task learning experiment.

"""

__all__ = ["STLExperiment"]

import logging
from typing import Any

import hydra
import lightning as L
from lightning import Callback, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from clarena.backbones import Backbone
from clarena.heads import HeadSTL
from clarena.stl_algorithms import STLAlgorithm
from clarena.stl_datasets import STLDataset
from clarena.utils.cfg import select_hyperparameters_from_config

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class STLExperiment:
    r"""The base class for single-task learning experiment."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the single-task learning experiment.
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        # always run the base-class check (calling through the class bypasses subclass overrides)
        STLExperiment.sanity_check(self)

        # required config fields
        self.eval: bool = cfg.eval
        r"""Whether to include evaluation phase."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""

        # components (declared here, assigned by the `instantiate_*` methods in `run()`)
        self.stl_dataset: STLDataset
        r"""STL dataset object."""
        self.backbone: Backbone
        r"""Backbone network object."""
        self.head: HeadSTL
        r"""STL output head object."""
        self.model: STLAlgorithm
        r"""STL model object."""
        self.optimizer: Optimizer
        r"""Optimizer object."""
        self.lr_scheduler: LRScheduler | None
        r"""Learning rate scheduler object."""
        self.lightning_loggers: list[Logger]
        r"""The list of initialized Lightning logger objects."""
        self.callbacks: list[Callback]
        r"""The list of initialized callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        Checks that every required field is present in `self.cfg` and is not `None`.
        Falsy but valid values (e.g. `global_seed: 0`, `eval: False`, or an empty
        list of callbacks) are accepted.

        **Raises:**
        - **KeyError**: if a required field is missing from the config or set to `None`.
        """

        # check required config fields
        required_config_fields = [
            "pipeline",
            "expr_name",
            "global_seed",
            "eval",  # read in `__init__`, so it must be present as well
            "stl_dataset",
            "stl_algorithm",
            "backbone",
            "optimizer",
            "lr_scheduler",
            "trainer",
            "metrics",
            "lightning_loggers",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            # test against `None` rather than truthiness: `.get` yields `None` for a
            # missing key, while valid falsy values such as `global_seed: 0` must pass
            if self.cfg.get(field) is None:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

    def instantiate_stl_dataset(
        self,
        stl_dataset_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the STL dataset object from `stl_dataset_cfg`."""
        pylogger.debug(
            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
            stl_dataset_cfg.get("_target_"),
        )
        self.stl_dataset = hydra.utils.instantiate(
            stl_dataset_cfg,
        )
        pylogger.debug(
            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
            stl_dataset_cfg.get("_target_"),
        )

    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
        r"""Instantiate the STL backbone network object from `backbone_cfg`."""
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
            backbone_cfg.get("_target_"),
        )
        self.backbone = hydra.utils.instantiate(backbone_cfg)
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
            backbone_cfg.get("_target_"),
        )

    def instantiate_head(self, input_dim: int) -> None:
        r"""Instantiate the STL output head object.

        **Args:**
        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug(
            "Instantiating STL head...",
        )
        self.head = HeadSTL(input_dim=input_dim)
        pylogger.debug("STL head instantiated! ")

    def instantiate_stl_algorithm(
        self,
        stl_algorithm_cfg: DictConfig,
        backbone: Backbone,
        head: HeadSTL,
        non_algorithmic_hparams: dict[str, Any],
    ) -> None:
        r"""Instantiate the STL algorithm object from `stl_algorithm_cfg`, `backbone`, `head` and `non_algorithmic_hparams`."""
        pylogger.debug(
            "STL algorithm is set as <%s>. Instantiating <%s> (clarena.stl_algorithms.STLAlgorithm)...",
            stl_algorithm_cfg.get("_target_"),
            stl_algorithm_cfg.get("_target_"),
        )
        self.model = hydra.utils.instantiate(
            stl_algorithm_cfg,
            backbone=backbone,
            head=head,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )
        pylogger.debug(
            "<%s> (clarena.stl_algorithms.STLAlgorithm) instantiated!",
            stl_algorithm_cfg.get("_target_"),
        )

    def instantiate_optimizer(
        self,
        optimizer_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the optimizer object from `optimizer_cfg`."""

        # partially instantiate the optimizer, because its `params` argument comes from
        # the Lightning module and cannot be passed yet (presumably the config sets
        # `_partial_: true` — confirm)
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
            optimizer_cfg.get("_target_"),
        )
        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
            optimizer_cfg.get("_target_"),
        )

    def instantiate_lr_scheduler(
        self,
        lr_scheduler_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`."""

        # partially instantiate the learning rate scheduler, because its optimizer argument
        # only exists once the Lightning module builds it (presumably the config sets
        # `_partial_: true` — confirm)
        pylogger.debug(
            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
            lr_scheduler_cfg.get("_target_"),
        )
        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
        pylogger.debug(
            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
            lr_scheduler_cfg.get("_target_"),
        )

    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
        r"""Instantiate the list of Lightning logger objects from `lightning_loggers_cfg`."""
        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
        self.lightning_loggers = [
            hydra.utils.instantiate(lightning_logger)
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks, metric callbacks first."""
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")

        # instantiate metric callbacks
        metric_callbacks = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # metric callbacks go first, followed by the other callbacks
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        lightning_loggers: list[Logger],
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""

        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(
            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
        )
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire experiment.

        **Args:**
        - **global_seed** (`int`): the seed passed to `lightning.seed_everything` (with `workers=True`).
        """
        # seed with the argument (not `self.global_seed`) so the seeded value is the logged one
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the single-task learning experiment.

        Seeds everything, instantiates all components from the config, sets up the task,
        fits the model on the STL dataset and, if `self.eval` is set, tests it afterwards.
        """
        self.set_global_seed(self.global_seed)

        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
        self.instantiate_head(input_dim=self.cfg.backbone.output_dim)
        self.instantiate_stl_algorithm(
            stl_algorithm_cfg=self.cfg.stl_algorithm,
            backbone=self.backbone,
            head=self.head,
            non_algorithmic_hparams=select_hyperparameters_from_config(
                cfg=self.cfg, type=self.cfg.pipeline
            ),
        )  # stl_algorithm should be instantiated after backbone and head
        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
        self.instantiate_lightning_loggers(
            lightning_loggers_cfg=self.cfg.lightning_loggers
        )
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after loggers and callbacks

        # setup task for dataset and model
        self.stl_dataset.setup_task()
        class_map = self.stl_dataset.get_class_map()
        pylogger.debug("Class map of the STL dataset: %s", class_map)
        self.model.setup_task(
            num_classes=len(class_map),
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
        )

        # fit the model on the STL dataset
        self.trainer.fit(
            model=self.model,
            datamodule=self.stl_dataset,
        )

        # evaluation after training and validation, only when enabled in the config
        if self.eval:
            self.trainer.test(
                model=self.model,
                datamodule=self.stl_dataset,
            )
class STLExperiment:
    r"""The base class for single-task learning experiment."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the single-task learning experiment.
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        # always run the base-class check (calling through the class bypasses subclass overrides)
        STLExperiment.sanity_check(self)

        # required config fields
        self.eval: bool = cfg.eval
        r"""Whether to include evaluation phase."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""

        # components (declared here, assigned by the `instantiate_*` methods in `run()`)
        self.stl_dataset: STLDataset
        r"""STL dataset object."""
        self.backbone: Backbone
        r"""Backbone network object."""
        self.head: HeadSTL
        r"""STL output head object."""
        self.model: STLAlgorithm
        r"""STL model object."""
        self.optimizer: Optimizer
        r"""Optimizer object."""
        self.lr_scheduler: LRScheduler | None
        r"""Learning rate scheduler object."""
        self.lightning_loggers: list[Logger]
        r"""The list of initialized Lightning logger objects."""
        self.callbacks: list[Callback]
        r"""The list of initialized callback objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        Checks that every required field is present in `self.cfg` and is not `None`.
        Falsy but valid values (e.g. `global_seed: 0`, `eval: False`, or an empty
        list of callbacks) are accepted.

        **Raises:**
        - **KeyError**: if a required field is missing from the config or set to `None`.
        """

        # check required config fields
        required_config_fields = [
            "pipeline",
            "expr_name",
            "global_seed",
            "eval",  # read in `__init__`, so it must be present as well
            "stl_dataset",
            "stl_algorithm",
            "backbone",
            "optimizer",
            "lr_scheduler",
            "trainer",
            "metrics",
            "lightning_loggers",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            # test against `None` rather than truthiness: `.get` yields `None` for a
            # missing key, while valid falsy values such as `global_seed: 0` must pass
            if self.cfg.get(field) is None:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

    def instantiate_stl_dataset(
        self,
        stl_dataset_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the STL dataset object from `stl_dataset_cfg`."""
        pylogger.debug(
            "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
            stl_dataset_cfg.get("_target_"),
        )
        self.stl_dataset = hydra.utils.instantiate(
            stl_dataset_cfg,
        )
        pylogger.debug(
            "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
            stl_dataset_cfg.get("_target_"),
        )

    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
        r"""Instantiate the STL backbone network object from `backbone_cfg`."""
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
            backbone_cfg.get("_target_"),
        )
        self.backbone = hydra.utils.instantiate(backbone_cfg)
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
            backbone_cfg.get("_target_"),
        )

    def instantiate_head(self, input_dim: int) -> None:
        r"""Instantiate the STL output head object.

        **Args:**
        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug(
            "Instantiating STL head...",
        )
        self.head = HeadSTL(input_dim=input_dim)
        pylogger.debug("STL head instantiated! ")

    def instantiate_stl_algorithm(
        self,
        stl_algorithm_cfg: DictConfig,
        backbone: Backbone,
        head: HeadSTL,
        non_algorithmic_hparams: dict[str, Any],
    ) -> None:
        r"""Instantiate the STL algorithm object from `stl_algorithm_cfg`, `backbone`, `head` and `non_algorithmic_hparams`."""
        pylogger.debug(
            "STL algorithm is set as <%s>. Instantiating <%s> (clarena.stl_algorithms.STLAlgorithm)...",
            stl_algorithm_cfg.get("_target_"),
            stl_algorithm_cfg.get("_target_"),
        )
        self.model = hydra.utils.instantiate(
            stl_algorithm_cfg,
            backbone=backbone,
            head=head,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )
        pylogger.debug(
            "<%s> (clarena.stl_algorithms.STLAlgorithm) instantiated!",
            stl_algorithm_cfg.get("_target_"),
        )

    def instantiate_optimizer(
        self,
        optimizer_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the optimizer object from `optimizer_cfg`."""

        # partially instantiate the optimizer, because its `params` argument comes from
        # the Lightning module and cannot be passed yet (presumably the config sets
        # `_partial_: true` — confirm)
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
            optimizer_cfg.get("_target_"),
        )
        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
            optimizer_cfg.get("_target_"),
        )

    def instantiate_lr_scheduler(
        self,
        lr_scheduler_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`."""

        # partially instantiate the learning rate scheduler, because its optimizer argument
        # only exists once the Lightning module builds it (presumably the config sets
        # `_partial_: true` — confirm)
        pylogger.debug(
            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
            lr_scheduler_cfg.get("_target_"),
        )
        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
        pylogger.debug(
            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
            lr_scheduler_cfg.get("_target_"),
        )

    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
        r"""Instantiate the list of Lightning logger objects from `lightning_loggers_cfg`."""
        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
        self.lightning_loggers = [
            hydra.utils.instantiate(lightning_logger)
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Instantiate the list of callback objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks, metric callbacks first."""
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")

        # instantiate metric callbacks
        metric_callbacks = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # metric callbacks go first, followed by the other callbacks
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        lightning_loggers: list[Logger],
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""

        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(
            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
        )
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire experiment.

        **Args:**
        - **global_seed** (`int`): the seed passed to `lightning.seed_everything` (with `workers=True`).
        """
        # seed with the argument (not `self.global_seed`) so the seeded value is the logged one
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the single-task learning experiment.

        Seeds everything, instantiates all components from the config, sets up the task,
        fits the model on the STL dataset and, if `self.eval` is set, tests it afterwards.
        """
        self.set_global_seed(self.global_seed)

        self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
        self.instantiate_head(input_dim=self.cfg.backbone.output_dim)
        self.instantiate_stl_algorithm(
            stl_algorithm_cfg=self.cfg.stl_algorithm,
            backbone=self.backbone,
            head=self.head,
            non_algorithmic_hparams=select_hyperparameters_from_config(
                cfg=self.cfg, type=self.cfg.pipeline
            ),
        )  # stl_algorithm should be instantiated after backbone and head
        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
        self.instantiate_lightning_loggers(
            lightning_loggers_cfg=self.cfg.lightning_loggers
        )
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after loggers and callbacks

        # setup task for dataset and model
        self.stl_dataset.setup_task()
        class_map = self.stl_dataset.get_class_map()
        pylogger.debug("Class map of the STL dataset: %s", class_map)
        self.model.setup_task(
            num_classes=len(class_map),
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
        )

        # fit the model on the STL dataset
        self.trainer.fit(
            model=self.model,
            datamodule=self.stl_dataset,
        )

        # evaluation after training and validation, only when enabled in the config
        if self.eval:
            self.trainer.test(
                model=self.model,
                datamodule=self.stl_dataset,
            )
The base class for single-task learning experiments.
def __init__(self, cfg: DictConfig) -> None:
    r"""Keep the experiment config, validate it and read its top-level fields.

    The components (dataset, backbone, head, model, optimizer, scheduler, loggers,
    callbacks, trainer) are only declared here; the `instantiate_*` methods called
    from `run()` assign them.

    **Args:**
    - **cfg** (`DictConfig`): the complete config dict for the single-task learning experiment.
    """
    self.cfg: DictConfig = cfg
    r"""The complete config dict."""

    # validate before any field is read; calling through the class (not `self.`)
    # always runs the base-class check even if a subclass overrides it
    STLExperiment.sanity_check(self)

    # top-level fields taken straight from the config
    self.eval: bool = cfg.eval
    r"""Whether to include evaluation phase."""
    self.global_seed: int = cfg.global_seed
    r"""The global seed for the entire experiment."""

    # component slots: annotated only, filled in later by the `instantiate_*` methods
    self.stl_dataset: STLDataset
    r"""STL dataset object."""
    self.backbone: Backbone
    r"""Backbone network object."""
    self.head: HeadSTL
    r"""STL output head object."""
    self.model: STLAlgorithm
    r"""STL model object."""
    self.optimizer: Optimizer
    r"""Optimizer object."""
    self.lr_scheduler: LRScheduler | None
    r"""Learning rate scheduler object."""
    self.lightning_loggers: list[Logger]
    r"""The list of initialized Lightning logger objects."""
    self.callbacks: list[Callback]
    r"""The list of initialized callback objects."""
    self.trainer: Trainer
    r"""Trainer object."""
Args:
- cfg (DictConfig): the complete config dict for the single-task learning experiment.
The list of initialized Lightning logger objects.
The list of initialized callback objects.
def sanity_check(self) -> None:
    r"""Sanity check for config.

    Checks that every required field is present in `self.cfg` and is not `None`.
    Falsy but valid values (e.g. `global_seed: 0`, `eval: False`, or an empty
    list of callbacks) are accepted.

    **Raises:**
    - **KeyError**: if a required field is missing from the config or set to `None`.
    """

    # check required config fields
    required_config_fields = [
        "pipeline",
        "expr_name",
        "global_seed",
        "eval",  # read in `__init__`, so it must be present as well
        "stl_dataset",
        "stl_algorithm",
        "backbone",
        "optimizer",
        "lr_scheduler",
        "trainer",
        "metrics",
        "lightning_loggers",
        "callbacks",
        "output_dir",
        # "hydra" is excluded as it doesn't appear
        "misc",
    ]
    for field in required_config_fields:
        # test against `None` rather than truthiness: `.get` yields `None` for a
        # missing key, while valid falsy values such as `global_seed: 0` must pass
        if self.cfg.get(field) is None:
            raise KeyError(
                f"Field `{field}` is required in the experiment index config."
            )
Sanity check for config.
def instantiate_stl_dataset(
    self,
    stl_dataset_cfg: DictConfig,
) -> None:
    r"""Build the STL dataset object from `stl_dataset_cfg` with Hydra and store it as `self.stl_dataset`.

    **Args:**
    - **stl_dataset_cfg** (`DictConfig`): the Hydra config of the STL dataset; its `_target_` is only read for logging here.
    """
    target = stl_dataset_cfg.get("_target_")
    pylogger.debug(
        "Instantiating STL dataset <%s> (clarena.stl_datasets.STLDataset)...",
        target,
    )
    dataset = hydra.utils.instantiate(stl_dataset_cfg)
    self.stl_dataset = dataset
    pylogger.debug(
        "STL dataset <%s> (clarena.stl_datasets.STLDataset) instantiated!",
        target,
    )
Instantiate the STL dataset object from stl_dataset_cfg.
def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
    r"""Instantiate the STL backbone network object from `backbone_cfg` and store it as `self.backbone`.

    **Args:**
    - **backbone_cfg** (`DictConfig`): the Hydra config of the backbone network; its `_target_` is only read for logging here.
    """
    target = backbone_cfg.get("_target_")
    pylogger.debug(
        "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
        target,
    )
    self.backbone = hydra.utils.instantiate(backbone_cfg)
    pylogger.debug(
        "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
        target,
    )
Instantiate the STL backbone network object from backbone_cfg.
def instantiate_head(self, input_dim: int) -> None:
    r"""Create the STL output head and store it as `self.head`.

    **Args:**
    - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
    """
    pylogger.debug("Instantiating STL head...")
    stl_head = HeadSTL(input_dim=input_dim)
    self.head = stl_head
    pylogger.debug("STL head instantiated! ")
Instantiate the STL output head object.
Args:
- input_dim (int): the input dimension of the head. Must be equal to the output_dim of the connected backbone.
def instantiate_stl_algorithm(
    self,
    stl_algorithm_cfg: DictConfig,
    backbone: Backbone,
    head: HeadSTL,
    non_algorithmic_hparams: dict[str, Any],
) -> None:
    r"""Instantiate the STL algorithm (the model) and store it as `self.model`.

    The already-built `backbone` and `head` and the `non_algorithmic_hparams` are
    passed to Hydra as extra keyword arguments alongside `stl_algorithm_cfg`.

    **Args:**
    - **stl_algorithm_cfg** (`DictConfig`): the Hydra config of the STL algorithm.
    - **backbone** (`Backbone`): the backbone network object.
    - **head** (`HeadSTL`): the STL output head object.
    - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters forwarded to the algorithm.
    """
    target = stl_algorithm_cfg.get("_target_")
    pylogger.debug(
        "STL algorithm is set as <%s>. Instantiating <%s> (clarena.stl_algorithms.STLAlgorithm)...",
        target,
        target,
    )
    injected = {
        "backbone": backbone,
        "head": head,
        "non_algorithmic_hparams": non_algorithmic_hparams,
    }
    self.model = hydra.utils.instantiate(stl_algorithm_cfg, **injected)
    pylogger.debug(
        "<%s> (clarena.stl_algorithms.STLAlgorithm) instantiated!",
        target,
    )
Instantiate the STL algorithm object from stl_algorithm_cfg, backbone, head and non_algorithmic_hparams.
def instantiate_optimizer(
    self,
    optimizer_cfg: DictConfig,
) -> None:
    r"""Partially instantiate the optimizer from `optimizer_cfg` and store it as `self.optimizer`.

    The optimizer is not fully built here: its `params` argument comes from the
    Lightning module and cannot be passed yet (presumably the config sets
    `_partial_: true` so Hydra returns a partial — confirm against the configs).

    **Args:**
    - **optimizer_cfg** (`DictConfig`): the Hydra config of the optimizer.
    """
    target = optimizer_cfg.get("_target_")
    pylogger.debug(
        "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
        target,
    )
    partial_optimizer = hydra.utils.instantiate(optimizer_cfg)
    self.optimizer = partial_optimizer
    pylogger.debug(
        "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
        target,
    )
Instantiate the optimizer object from optimizer_cfg.
def instantiate_lr_scheduler(
    self,
    lr_scheduler_cfg: DictConfig,
) -> None:
    r"""Partially instantiate the learning rate scheduler from `lr_scheduler_cfg` and store it as `self.lr_scheduler`.

    Like the optimizer, the scheduler cannot be fully built here because the objects
    it depends on only exist inside the Lightning module (presumably the config
    sets `_partial_: true` — confirm against the configs).

    **Args:**
    - **lr_scheduler_cfg** (`DictConfig`): the Hydra config of the learning rate scheduler.
    """
    target = lr_scheduler_cfg.get("_target_")
    pylogger.debug(
        "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
        target,
    )
    partial_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
    self.lr_scheduler = partial_scheduler
    pylogger.debug(
        "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
        target,
    )
Instantiate the learning rate scheduler object from lr_scheduler_cfg.
def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
    r"""Instantiate one Lightning logger per entry of `lightning_loggers_cfg` and store them as `self.lightning_loggers`.

    **Args:**
    - **lightning_loggers_cfg** (`DictConfig`): a mapping whose values are Hydra configs of Lightning loggers (the keys are ignored).
    """
    pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
    self.lightning_loggers = list(
        map(hydra.utils.instantiate, lightning_loggers_cfg.values())
    )
    pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")
Instantiate the list of lightning loggers objects from lightning_loggers_cfg.
def instantiate_callbacks(
    self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
) -> None:
    r"""Instantiate the metric callbacks and the other callbacks, and store them together as `self.callbacks`.

    **Args:**
    - **metrics_cfg** (`DictConfig`): a list of Hydra configs of metric callbacks.
    - **callbacks_cfg** (`DictConfig`): a list of Hydra configs of callbacks other than the metric callbacks.

    The resulting list holds the metric callbacks first, followed by the other callbacks.
    """
    pylogger.debug("Instantiating callbacks (lightning.Callback)...")

    build = hydra.utils.instantiate
    # metric callbacks are built (and placed) before the remaining callbacks
    self.callbacks = [build(metric_cfg) for metric_cfg in metrics_cfg] + [
        build(callback_cfg) for callback_cfg in callbacks_cfg
    ]
    pylogger.debug("Callbacks (lightning.Callback) instantiated!")
Instantiate the list of callback objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.
def instantiate_trainer(
    self,
    trainer_cfg: DictConfig,
    lightning_loggers: list[Logger],
    callbacks: list[Callback],
) -> None:
    r"""Build the Lightning trainer from `trainer_cfg` and store it as `self.trainer`.

    **Args:**
    - **trainer_cfg** (`DictConfig`): the Hydra config of the trainer.
    - **lightning_loggers** (`list[Logger]`): passed to the trainer as its `logger` argument.
    - **callbacks** (`list[Callback]`): passed to the trainer as its `callbacks` argument.
    """
    pylogger.debug("Instantiating trainer (lightning.Trainer)...")
    built_trainer = hydra.utils.instantiate(
        trainer_cfg,
        logger=lightning_loggers,
        callbacks=callbacks,
    )
    self.trainer = built_trainer
    pylogger.debug("Trainer (lightning.Trainer) instantiated!")
Instantiate the trainer object from trainer_cfg, lightning_loggers, and callbacks.
def set_global_seed(self, global_seed: int) -> None:
    r"""Set the `global_seed` for the entire experiment.

    **Args:**
    - **global_seed** (`int`): the seed passed to `lightning.seed_everything` (with `workers=True`).
    """
    # seed with the argument (not `self.global_seed`) so the value that is seeded is
    # exactly the value that is logged below
    L.seed_everything(global_seed, workers=True)
    pylogger.debug("Global seed is set as %d.", global_seed)
Set the global_seed for the entire experiment.
def run(self) -> None:
    r"""The main method to run the single-task learning experiment.

    Seeds everything, instantiates all components from the config, sets up the task,
    fits the model on the STL dataset and, if `self.eval` is set, tests it afterwards.
    """
    self.set_global_seed(self.global_seed)

    self.instantiate_stl_dataset(stl_dataset_cfg=self.cfg.stl_dataset)
    self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
    self.instantiate_head(input_dim=self.cfg.backbone.output_dim)
    self.instantiate_stl_algorithm(
        stl_algorithm_cfg=self.cfg.stl_algorithm,
        backbone=self.backbone,
        head=self.head,
        non_algorithmic_hparams=select_hyperparameters_from_config(
            cfg=self.cfg, type=self.cfg.pipeline
        ),
    )  # stl_algorithm should be instantiated after backbone and head
    self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
    self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
    self.instantiate_lightning_loggers(
        lightning_loggers_cfg=self.cfg.lightning_loggers
    )
    self.instantiate_callbacks(
        metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
    )
    self.instantiate_trainer(
        trainer_cfg=self.cfg.trainer,
        lightning_loggers=self.lightning_loggers,
        callbacks=self.callbacks,
    )  # trainer should be instantiated after loggers and callbacks

    # setup task for dataset and model
    self.stl_dataset.setup_task()
    # query the class map once; it goes to the debug log instead of a stray print
    class_map = self.stl_dataset.get_class_map()
    pylogger.debug("Class map of the STL dataset: %s", class_map)
    self.model.setup_task(
        num_classes=len(class_map),
        optimizer=self.optimizer,
        lr_scheduler=self.lr_scheduler,
    )

    # fit the model on the STL dataset
    self.trainer.fit(
        model=self.model,
        datamodule=self.stl_dataset,
    )

    # evaluation after training and validation, only when enabled in the config
    if self.eval:
        self.trainer.test(
            model=self.model,
            datamodule=self.stl_dataset,
        )
The main method to run the single-task learning experiment.