clarena.pipelines.mtl_expr

The submodule in pipelines for multi-task learning experiment.

  1r"""
  2The submodule in `pipelines` for multi-task learning experiment.
  3"""
  4
  5__all__ = ["MTLExperiment"]
  6
  7import logging
  8from typing import Any
  9
 10import hydra
 11import lightning as L
 12from lightning import Callback, Trainer
 13from lightning.pytorch.loggers import Logger
 14from omegaconf import DictConfig
 15from torch.optim import Optimizer
 16from torch.optim.lr_scheduler import LRScheduler
 17
 18from clarena.backbones import Backbone, CLBackbone
 19from clarena.heads import HeadsMTL
 20from clarena.mtl_algorithms import MTLAlgorithm
 21from clarena.mtl_datasets import MTLDataset
 22from clarena.utils.cfg import select_hyperparameters_from_config
 23
 24# always get logger for built-in logging in each module
 25pylogger = logging.getLogger(__name__)
 26
 27
 28class MTLExperiment:
 29    r"""The base class for multi-task learning experiment."""
 30
 31    def __init__(self, cfg: DictConfig) -> None:
 32        r"""
 33        **Args:**
 34        - **cfg** (`DictConfig`): the complete config dict for the multi-task learning experiment.
 35        """
 36        self.cfg: DictConfig = cfg
 37        r"""The complete config dict."""
 38
 39        MTLExperiment.sanity_check(self)
 40
 41        # required config fields
 42        self.train_tasks: list[int] = (
 43            cfg.train_tasks
 44            if isinstance(cfg.train_tasks, list)
 45            else list(range(1, cfg.train_tasks + 1))
 46        )
 47        r"""The list of tasks to train."""
 48        self.eval_tasks: list[int] = (
 49            cfg.eval_tasks
 50            if isinstance(cfg.eval_tasks, list)
 51            else list(range(1, cfg.eval_tasks + 1))
 52        )
 53        r"""The list of tasks to evaluate."""
 54        self.global_seed: int = cfg.global_seed
 55        r"""The global seed for the entire experiment."""
 56        self.output_dir: str = cfg.output_dir
 57        r"""The folder for storing the experiment results."""
 58
 59        # components
 60        self.mtl_dataset: MTLDataset
 61        r"""MTL dataset object."""
 62        self.backbone: CLBackbone
 63        r"""Backbone network object."""
 64        self.heads: HeadsMTL
 65        r"""MTL output heads object."""
 66        self.model: MTLAlgorithm
 67        r"""MTL model object."""
 68        self.optimizer: Optimizer
 69        r"""Optimizer object."""
 70        self.lr_scheduler: LRScheduler | None
 71        r"""Learning rate scheduler object."""
 72        self.lightning_loggers: list[Logger]
 73        r"""The list of initialized lightning loggers objects."""
 74        self.callbacks: list[Callback]
 75        r"""The list of initialized callbacks objects."""
 76        self.trainer: Trainer
 77        r"""Trainer object."""
 78
 79    def sanity_check(self) -> None:
 80        r"""Sanity check for config."""
 81
 82        # check required config fields
 83        required_config_fields = [
 84            "pipeline",
 85            "expr_name",
 86            "train_tasks",
 87            "eval_tasks",
 88            "global_seed",
 89            "mtl_dataset",
 90            "mtl_algorithm",
 91            "backbone",
 92            "optimizer",
 93            "lr_scheduler",
 94            "trainer",
 95            "metrics",
 96            "lightning_loggers",
 97            "callbacks",
 98            "output_dir",
 99            # "hydra" is excluded as it doesn't appear
100            "misc",
101        ]
102        for field in required_config_fields:
103            if not self.cfg.get(field):
104                raise KeyError(
105                    f"Field `{field}` is required in the experiment index config."
106                )
107
108        # get dataset number of tasks
109        if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL":
110            cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset")
111            if cl_dataset_cfg.get("num_tasks"):
112                num_tasks = cl_dataset_cfg.get("num_tasks")
113            elif cl_dataset_cfg.get("class_split"):
114                num_tasks = len(cl_dataset_cfg.class_split)
115            elif cl_dataset_cfg.get("datasets"):
116                num_tasks = len(cl_dataset_cfg.datasets)
117            else:
118                raise KeyError(
119                    "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
120                )
121        else:
122            if self.cfg.mtl_dataset.get("num_tasks"):
123                num_tasks = self.cfg.mtl_dataset.num_tasks
124            else:
125                raise KeyError(
126                    "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config."
127                )
128
129        # check train_tasks
130        train_tasks = self.cfg.train_tasks
131        if isinstance(train_tasks, list):
132            if len(train_tasks) < 1:
133                raise ValueError("`train_tasks` must contain at least one task.")
134            if any(t < 1 or t > num_tasks for t in train_tasks):
135                raise ValueError(
136                    f"All task IDs in `train_tasks` must be between 1 and {num_tasks}."
137                )
138        elif isinstance(train_tasks, int):
139            if train_tasks < 0 or train_tasks > num_tasks:
140                raise ValueError(
141                    f"`train_tasks` as integer must be between 0 and {num_tasks}."
142                )
143        else:
144            raise TypeError(
145                "`train_tasks` must be either a list of integers or an integer."
146            )
147
148        # check eval_tasks
149        eval_tasks = self.cfg.eval_tasks
150        if isinstance(eval_tasks, list):
151            if len(eval_tasks) < 1:
152                raise ValueError("`eval_tasks` must contain at least one task.")
153            if any(t < 1 or t > num_tasks for t in eval_tasks):
154                raise ValueError(
155                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
156                )
157        elif isinstance(eval_tasks, int):
158            if eval_tasks < 0 or eval_tasks > num_tasks:
159                raise ValueError(
160                    f"`eval_tasks` as integer must be between 0 and {num_tasks}."
161                )
162        else:
163            raise TypeError(
164                "`eval_tasks` must be either a list of integers or an integer."
165            )
166
167    def instantiate_mtl_dataset(
168        self,
169        mtl_dataset_cfg: DictConfig,
170    ) -> None:
171        r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`."""
172        pylogger.debug(
173            "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset)...",
174            mtl_dataset_cfg.get("_target_"),
175        )
176        self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg)
177        pylogger.debug(
178            "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated!",
179            mtl_dataset_cfg.get("_target_"),
180        )
181
182    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
183        r"""Instantiate the MTL backbone network object from `backbone_cfg`."""
184        pylogger.debug(
185            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
186            backbone_cfg.get("_target_"),
187        )
188        self.backbone = hydra.utils.instantiate(backbone_cfg)
189        pylogger.debug(
190            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
191            backbone_cfg.get("_target_"),
192        )
193
194    def instantiate_heads(
195        self,
196        input_dim: int,
197    ) -> None:
198        r"""Instantiate the MTL output heads object.
199
200        **Args:**
201        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
202        """
203        pylogger.debug(
204            "Instantiating MTL heads...",
205        )
206        self.heads = HeadsMTL(input_dim=input_dim)
207        pylogger.debug("MTL heads instantiated! ")
208
209    def instantiate_mtl_algorithm(
210        self,
211        mtl_algorithm_cfg: DictConfig,
212        backbone: Backbone,
213        heads: HeadsMTL,
214        non_algorithmic_hparams: dict[str, Any],
215    ) -> None:
216        r"""Instantiate the mtl_algorithm object from `mtl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
217        pylogger.debug(
218            "MTL algorithm is set as <%s>. Instantiating <%s> (clarena.mtl_algorithms.MTLAlgorithm)...",
219            mtl_algorithm_cfg.get("_target_"),
220            mtl_algorithm_cfg.get("_target_"),
221        )
222        self.model = hydra.utils.instantiate(
223            mtl_algorithm_cfg,
224            backbone=backbone,
225            heads=heads,
226            non_algorithmic_hparams=non_algorithmic_hparams,
227        )
228        pylogger.debug(
229            "<%s> (clarena.mtl_algorithms.MTLAlgorithm) instantiated!",
230            mtl_algorithm_cfg.get("_target_"),
231        )
232
233    def instantiate_optimizer(
234        self,
235        optimizer_cfg: DictConfig,
236    ) -> None:
237        r"""Instantiate the optimizer object from `optimizer_cfg`."""
238
239        # partially instantiate optimizer as the 'params' argument is from Lightning Modules cannot be passed for now.
240        pylogger.debug(
241            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
242            optimizer_cfg.get("_target_"),
243        )
244        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
245        pylogger.debug(
246            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
247            optimizer_cfg.get("_target_"),
248        )
249
250    def instantiate_lr_scheduler(
251        self,
252        lr_scheduler_cfg: DictConfig,
253    ) -> None:
254        r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`."""
255
256        # partially instantiate learning rate scheduler as the 'params' argument is from Lightning Modules cannot be passed for now.
257        pylogger.debug(
258            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
259            lr_scheduler_cfg.get("_target_"),
260        )
261        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
262        pylogger.debug(
263            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
264            lr_scheduler_cfg.get("_target_"),
265        )
266
267    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
268        r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`."""
269        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
270        self.lightning_loggers = [
271            hydra.utils.instantiate(lightning_logger)
272            for lightning_logger in lightning_loggers_cfg.values()
273        ]
274        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")
275
276    def instantiate_callbacks(
277        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
278    ) -> None:
279        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
280        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
281
282        # instantiate metric callbacks
283        metric_callbacks = [
284            hydra.utils.instantiate(callback) for callback in metrics_cfg
285        ]
286
287        # instantiate other callbacks
288        other_callbacks = [
289            hydra.utils.instantiate(callback) for callback in callbacks_cfg
290        ]
291
292        # add metric callbacks to the list of callbacks
293        self.callbacks = metric_callbacks + other_callbacks
294        pylogger.debug("Callbacks (lightning.Callback) instantiated!")
295
296    def instantiate_trainer(
297        self,
298        trainer_cfg: DictConfig,
299        lightning_loggers: list[Logger],
300        callbacks: list[Callback],
301    ) -> None:
302        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
303
304        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
305        self.trainer = hydra.utils.instantiate(
306            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
307        )
308        pylogger.debug("Trainer (lightning.Trainer) instantiated!")
309
310    def set_global_seed(self, global_seed: int) -> None:
311        r"""Set the `global_seed` for the entire experiment."""
312        L.seed_everything(self.global_seed, workers=True)
313        pylogger.debug("Global seed is set as %d.", global_seed)
314
315    def run(self) -> None:
316        r"""The main method to run the multi-task learning experiment."""
317        self.set_global_seed(self.global_seed)
318
319        self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset)
320        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
321        self.instantiate_heads(input_dim=self.cfg.backbone.output_dim)
322        self.instantiate_mtl_algorithm(
323            mtl_algorithm_cfg=self.cfg.mtl_algorithm,
324            backbone=self.backbone,
325            heads=self.heads,
326            non_algorithmic_hparams=select_hyperparameters_from_config(
327                cfg=self.cfg, type=self.cfg.pipeline
328            ),
329        )  # mtl_algorithm should be instantiated after backbone and heads
330        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
331        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
332        self.instantiate_lightning_loggers(
333            lightning_loggers_cfg=self.cfg.lightning_loggers
334        )
335        self.instantiate_callbacks(
336            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
337        )
338        self.instantiate_trainer(
339            trainer_cfg=self.cfg.trainer,
340            lightning_loggers=self.lightning_loggers,
341            callbacks=self.callbacks,
342        )  # trainer should be instantiated after lightning loggers and callbacks
343
344        # setup tasks for dataset and model
345        self.mtl_dataset.setup_tasks_expr(
346            train_tasks=self.train_tasks, eval_tasks=self.eval_tasks
347        )
348        self.model.setup_tasks(
349            task_ids=self.train_tasks,
350            num_classes={
351                task_id: len(self.mtl_dataset.get_mtl_class_map(task_id))
352                for task_id in self.train_tasks
353            },
354            optimizer=self.optimizer,
355            lr_scheduler=self.lr_scheduler,
356        )
357
358        # train and validate the model
359        self.trainer.fit(
360            model=self.model,
361            datamodule=self.mtl_dataset,
362        )
363
364        # evaluation after training and validation
365        self.trainer.test(
366            model=self.model,
367            datamodule=self.mtl_dataset,
368        )
class MTLExperiment:
    r"""The base class for multi-task learning experiment."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the multi-task learning experiment.
        """
        self.cfg: DictConfig = cfg
        r"""The complete config dict."""

        self.sanity_check()

        # required config fields. An integer value of `train_tasks` / `eval_tasks`
        # is shorthand for the full task ID list 1..N.
        self.train_tasks: list[int] = (
            cfg.train_tasks
            if isinstance(cfg.train_tasks, list)
            else list(range(1, cfg.train_tasks + 1))
        )
        r"""The list of tasks to train."""
        self.eval_tasks: list[int] = (
            cfg.eval_tasks
            if isinstance(cfg.eval_tasks, list)
            else list(range(1, cfg.eval_tasks + 1))
        )
        r"""The list of tasks to evaluate."""
        self.global_seed: int = cfg.global_seed
        r"""The global seed for the entire experiment."""
        self.output_dir: str = cfg.output_dir
        r"""The folder for storing the experiment results."""

        # components; declared here, instantiated later in `run()`
        self.mtl_dataset: MTLDataset
        r"""MTL dataset object."""
        self.backbone: CLBackbone
        r"""Backbone network object."""
        self.heads: HeadsMTL
        r"""MTL output heads object."""
        self.model: MTLAlgorithm
        r"""MTL model object."""
        self.optimizer: Optimizer
        r"""Optimizer object."""
        self.lr_scheduler: LRScheduler | None
        r"""Learning rate scheduler object."""
        self.lightning_loggers: list[Logger]
        r"""The list of initialized lightning loggers objects."""
        self.callbacks: list[Callback]
        r"""The list of initialized callbacks objects."""
        self.trainer: Trainer
        r"""Trainer object."""

    def sanity_check(self) -> None:
        r"""Sanity check for config.

        **Raises:**
        - **KeyError**: if a required config field is missing, or the number of tasks cannot be derived from the dataset config.
        - **ValueError**: if `train_tasks` / `eval_tasks` contain out-of-range values.
        - **TypeError**: if `train_tasks` / `eval_tasks` are neither a list of integers nor an integer.
        """

        # check required config fields
        required_config_fields = [
            "pipeline",
            "expr_name",
            "train_tasks",
            "eval_tasks",
            "global_seed",
            "mtl_dataset",
            "mtl_algorithm",
            "backbone",
            "optimizer",
            "lr_scheduler",
            "trainer",
            "metrics",
            "lightning_loggers",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]
        for field in required_config_fields:
            # use an explicit `is None` check (not truthiness) so legitimate
            # falsy values, e.g. `global_seed: 0` or an empty `misc:` mapping,
            # are not mistaken for missing fields
            if self.cfg.get(field) is None:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # get dataset number of tasks, then validate the task selections
        num_tasks = self._get_num_tasks()
        self._check_task_field("train_tasks", self.cfg.train_tasks, num_tasks)
        self._check_task_field("eval_tasks", self.cfg.eval_tasks, num_tasks)

    def _get_num_tasks(self) -> int:
        r"""Derive the dataset's number of tasks from the `mtl_dataset` config.

        **Returns:**
        - **num_tasks** (`int`): the number of tasks of the dataset.

        **Raises:**
        - **KeyError**: if the number of tasks cannot be determined from the config.
        """
        if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL":
            # the task count lives in the wrapped cl_dataset config and is
            # expressed differently depending on the CL dataset flavor
            cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset")
            if cl_dataset_cfg.get("num_tasks"):
                return cl_dataset_cfg.get("num_tasks")
            if cl_dataset_cfg.get("class_split"):
                return len(cl_dataset_cfg.class_split)
            if cl_dataset_cfg.get("datasets"):
                return len(cl_dataset_cfg.datasets)
            raise KeyError(
                "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
            )
        if self.cfg.mtl_dataset.get("num_tasks"):
            return self.cfg.mtl_dataset.num_tasks
        raise KeyError(
            "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config."
        )

    @staticmethod
    def _check_task_field(
        field_name: str, tasks: list[int] | int, num_tasks: int
    ) -> None:
        r"""Validate a task selection config field (`train_tasks` / `eval_tasks`).

        **Args:**
        - **field_name** (`str`): the config field name, used in error messages.
        - **tasks** (`list[int]` | `int`): the field value; either an explicit list of task IDs or an integer N meaning tasks 1..N.
        - **num_tasks** (`int`): the number of tasks of the dataset.

        **Raises:**
        - **ValueError**: if the value is out of range.
        - **TypeError**: if the value is neither a list nor an integer.
        """
        if isinstance(tasks, list):
            if len(tasks) < 1:
                raise ValueError(f"`{field_name}` must contain at least one task.")
            if any(t < 1 or t > num_tasks for t in tasks):
                raise ValueError(
                    f"All task IDs in `{field_name}` must be between 1 and {num_tasks}."
                )
        elif isinstance(tasks, int):
            # NOTE(review): 0 is accepted here (yielding an empty task list via
            # the integer shorthand) while an empty list is rejected above —
            # presumably intentional; confirm against callers
            if tasks < 0 or tasks > num_tasks:
                raise ValueError(
                    f"`{field_name}` as integer must be between 0 and {num_tasks}."
                )
        else:
            raise TypeError(
                f"`{field_name}` must be either a list of integers or an integer."
            )

    def instantiate_mtl_dataset(
        self,
        mtl_dataset_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`."""
        pylogger.debug(
            "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset)...",
            mtl_dataset_cfg.get("_target_"),
        )
        self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg)
        pylogger.debug(
            "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated!",
            mtl_dataset_cfg.get("_target_"),
        )

    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
        r"""Instantiate the MTL backbone network object from `backbone_cfg`."""
        pylogger.debug(
            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
            backbone_cfg.get("_target_"),
        )
        self.backbone = hydra.utils.instantiate(backbone_cfg)
        pylogger.debug(
            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
            backbone_cfg.get("_target_"),
        )

    def instantiate_heads(
        self,
        input_dim: int,
    ) -> None:
        r"""Instantiate the MTL output heads object.

        **Args:**
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        pylogger.debug(
            "Instantiating MTL heads...",
        )
        self.heads = HeadsMTL(input_dim=input_dim)
        pylogger.debug("MTL heads instantiated! ")

    def instantiate_mtl_algorithm(
        self,
        mtl_algorithm_cfg: DictConfig,
        backbone: Backbone,
        heads: HeadsMTL,
        non_algorithmic_hparams: dict[str, Any],
    ) -> None:
        r"""Instantiate the mtl_algorithm object from `mtl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
        pylogger.debug(
            "MTL algorithm is set as <%s>. Instantiating <%s> (clarena.mtl_algorithms.MTLAlgorithm)...",
            mtl_algorithm_cfg.get("_target_"),
            mtl_algorithm_cfg.get("_target_"),
        )
        self.model = hydra.utils.instantiate(
            mtl_algorithm_cfg,
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )
        pylogger.debug(
            "<%s> (clarena.mtl_algorithms.MTLAlgorithm) instantiated!",
            mtl_algorithm_cfg.get("_target_"),
        )

    def instantiate_optimizer(
        self,
        optimizer_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the optimizer object from `optimizer_cfg`."""

        # partially instantiate optimizer as the 'params' argument is from Lightning Modules cannot be passed for now.
        pylogger.debug(
            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
            optimizer_cfg.get("_target_"),
        )
        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
        pylogger.debug(
            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
            optimizer_cfg.get("_target_"),
        )

    def instantiate_lr_scheduler(
        self,
        lr_scheduler_cfg: DictConfig,
    ) -> None:
        r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`."""

        # partially instantiate learning rate scheduler as the 'params' argument is from Lightning Modules cannot be passed for now.
        pylogger.debug(
            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
            lr_scheduler_cfg.get("_target_"),
        )
        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
        pylogger.debug(
            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
            lr_scheduler_cfg.get("_target_"),
        )

    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
        r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`."""
        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
        self.lightning_loggers = [
            hydra.utils.instantiate(lightning_logger)
            for lightning_logger in lightning_loggers_cfg.values()
        ]
        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")

    def instantiate_callbacks(
        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
    ) -> None:
        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
        pylogger.debug("Instantiating callbacks (lightning.Callback)...")

        # instantiate metric callbacks
        metric_callbacks = [
            hydra.utils.instantiate(callback) for callback in metrics_cfg
        ]

        # instantiate other callbacks
        other_callbacks = [
            hydra.utils.instantiate(callback) for callback in callbacks_cfg
        ]

        # add metric callbacks to the list of callbacks
        self.callbacks = metric_callbacks + other_callbacks
        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

    def instantiate_trainer(
        self,
        trainer_cfg: DictConfig,
        lightning_loggers: list[Logger],
        callbacks: list[Callback],
    ) -> None:
        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""

        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
        self.trainer = hydra.utils.instantiate(
            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
        )
        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

    def set_global_seed(self, global_seed: int) -> None:
        r"""Set the `global_seed` for the entire experiment.

        **Args:**
        - **global_seed** (`int`): the seed to apply across all random number generators and workers.
        """
        # seed with the argument (not `self.global_seed`) so that the seed
        # actually applied always matches the one logged below
        L.seed_everything(global_seed, workers=True)
        pylogger.debug("Global seed is set as %d.", global_seed)

    def run(self) -> None:
        r"""The main method to run the multi-task learning experiment."""
        self.set_global_seed(self.global_seed)

        self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset)
        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
        self.instantiate_heads(input_dim=self.cfg.backbone.output_dim)
        self.instantiate_mtl_algorithm(
            mtl_algorithm_cfg=self.cfg.mtl_algorithm,
            backbone=self.backbone,
            heads=self.heads,
            non_algorithmic_hparams=select_hyperparameters_from_config(
                cfg=self.cfg, type=self.cfg.pipeline
            ),
        )  # mtl_algorithm should be instantiated after backbone and heads
        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
        self.instantiate_lightning_loggers(
            lightning_loggers_cfg=self.cfg.lightning_loggers
        )
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
        )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
        )  # trainer should be instantiated after lightning loggers and callbacks

        # setup tasks for dataset and model
        self.mtl_dataset.setup_tasks_expr(
            train_tasks=self.train_tasks, eval_tasks=self.eval_tasks
        )
        self.model.setup_tasks(
            task_ids=self.train_tasks,
            num_classes={
                task_id: len(self.mtl_dataset.get_mtl_class_map(task_id))
                for task_id in self.train_tasks
            },
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
        )

        # train and validate the model
        self.trainer.fit(
            model=self.model,
            datamodule=self.mtl_dataset,
        )

        # evaluation after training and validation
        self.trainer.test(
            model=self.model,
            datamodule=self.mtl_dataset,
        )

The base class for multi-task learning experiment.

MTLExperiment(cfg: omegaconf.dictconfig.DictConfig)
32    def __init__(self, cfg: DictConfig) -> None:
33        r"""
34        **Args:**
35        - **cfg** (`DictConfig`): the complete config dict for the multi-task learning experiment.
36        """
37        self.cfg: DictConfig = cfg
38        r"""The complete config dict."""
39
40        MTLExperiment.sanity_check(self)
41
42        # required config fields
43        self.train_tasks: list[int] = (
44            cfg.train_tasks
45            if isinstance(cfg.train_tasks, list)
46            else list(range(1, cfg.train_tasks + 1))
47        )
48        r"""The list of tasks to train."""
49        self.eval_tasks: list[int] = (
50            cfg.eval_tasks
51            if isinstance(cfg.eval_tasks, list)
52            else list(range(1, cfg.eval_tasks + 1))
53        )
54        r"""The list of tasks to evaluate."""
55        self.global_seed: int = cfg.global_seed
56        r"""The global seed for the entire experiment."""
57        self.output_dir: str = cfg.output_dir
58        r"""The folder for storing the experiment results."""
59
60        # components
61        self.mtl_dataset: MTLDataset
62        r"""MTL dataset object."""
63        self.backbone: CLBackbone
64        r"""Backbone network object."""
65        self.heads: HeadsMTL
66        r"""MTL output heads object."""
67        self.model: MTLAlgorithm
68        r"""MTL model object."""
69        self.optimizer: Optimizer
70        r"""Optimizer object."""
71        self.lr_scheduler: LRScheduler | None
72        r"""Learning rate scheduler object."""
73        self.lightning_loggers: list[Logger]
74        r"""The list of initialized lightning loggers objects."""
75        self.callbacks: list[Callback]
76        r"""The list of initialized callbacks objects."""
77        self.trainer: Trainer
78        r"""Trainer object."""

Args:

  • cfg (DictConfig): the complete config dict for the multi-task learning experiment.
cfg: omegaconf.dictconfig.DictConfig

The complete config dict.

train_tasks: list[int]

The list of tasks to train.

eval_tasks: list[int]

The list of tasks to evaluate.

global_seed: int

The global seed for the entire experiment.

output_dir: str

The folder for storing the experiment results.

MTL dataset object.

Backbone network object.

MTL output heads object.

MTL model object.

optimizer: torch.optim.optimizer.Optimizer

Optimizer object.

lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None

Learning rate scheduler object.

lightning_loggers: list[lightning.pytorch.loggers.logger.Logger]

The list of initialized lightning loggers objects.

callbacks: list[lightning.pytorch.callbacks.callback.Callback]

The list of initialized callbacks objects.

trainer: lightning.pytorch.trainer.trainer.Trainer

Trainer object.

def sanity_check(self) -> None:
 80    def sanity_check(self) -> None:
 81        r"""Sanity check for config."""
 82
 83        # check required config fields
 84        required_config_fields = [
 85            "pipeline",
 86            "expr_name",
 87            "train_tasks",
 88            "eval_tasks",
 89            "global_seed",
 90            "mtl_dataset",
 91            "mtl_algorithm",
 92            "backbone",
 93            "optimizer",
 94            "lr_scheduler",
 95            "trainer",
 96            "metrics",
 97            "lightning_loggers",
 98            "callbacks",
 99            "output_dir",
100            # "hydra" is excluded as it doesn't appear
101            "misc",
102        ]
103        for field in required_config_fields:
104            if not self.cfg.get(field):
105                raise KeyError(
106                    f"Field `{field}` is required in the experiment index config."
107                )
108
109        # get dataset number of tasks
110        if self.cfg.mtl_dataset._target_ == "clarena.mtl_datasets.MTLDatasetFromCL":
111            cl_dataset_cfg = self.cfg.mtl_dataset.get("cl_dataset")
112            if cl_dataset_cfg.get("num_tasks"):
113                num_tasks = cl_dataset_cfg.get("num_tasks")
114            elif cl_dataset_cfg.get("class_split"):
115                num_tasks = len(cl_dataset_cfg.class_split)
116            elif cl_dataset_cfg.get("datasets"):
117                num_tasks = len(cl_dataset_cfg.datasets)
118            else:
119                raise KeyError(
120                    "`num_tasks` is required in cl_dataset config under mtl_dataset config. Please specify `num_tasks` (for `CLPermutedDataset`) or `class_split` (for `CLSplitDataset`) or `datasets` (for `CLCombinedDataset`) in cl_dataset config."
121                )
122        else:
123            if self.cfg.mtl_dataset.get("num_tasks"):
124                num_tasks = self.cfg.mtl_dataset.num_tasks
125            else:
126                raise KeyError(
127                    "`num_tasks` is required in mtl_dataset config. Please specify `num_tasks` in mtl_dataset config."
128                )
129
130        # check train_tasks
131        train_tasks = self.cfg.train_tasks
132        if isinstance(train_tasks, list):
133            if len(train_tasks) < 1:
134                raise ValueError("`train_tasks` must contain at least one task.")
135            if any(t < 1 or t > num_tasks for t in train_tasks):
136                raise ValueError(
137                    f"All task IDs in `train_tasks` must be between 1 and {num_tasks}."
138                )
139        elif isinstance(train_tasks, int):
140            if train_tasks < 0 or train_tasks > num_tasks:
141                raise ValueError(
142                    f"`train_tasks` as integer must be between 0 and {num_tasks}."
143                )
144        else:
145            raise TypeError(
146                "`train_tasks` must be either a list of integers or an integer."
147            )
148
149        # check eval_tasks
150        eval_tasks = self.cfg.eval_tasks
151        if isinstance(eval_tasks, list):
152            if len(eval_tasks) < 1:
153                raise ValueError("`eval_tasks` must contain at least one task.")
154            if any(t < 1 or t > num_tasks for t in eval_tasks):
155                raise ValueError(
156                    f"All task IDs in `eval_tasks` must be between 1 and {num_tasks}."
157                )
158        elif isinstance(eval_tasks, int):
159            if eval_tasks < 0 or eval_tasks > num_tasks:
160                raise ValueError(
161                    f"`eval_tasks` as integer must be between 0 and {num_tasks}."
162                )
163        else:
164            raise TypeError(
165                "`eval_tasks` must be either a list of integers or an integer."
166            )

Sanity check for config.

def instantiate_mtl_dataset(self, mtl_dataset_cfg: omegaconf.dictconfig.DictConfig) -> None:
168    def instantiate_mtl_dataset(
169        self,
170        mtl_dataset_cfg: DictConfig,
171    ) -> None:
172        r"""Instantiate the MTL dataset object from `mtl_dataset_cfg`."""
173        pylogger.debug(
174            "Instantiating MTL dataset <%s> (clarena.mtl_datasets.MTLDataset)...",
175            mtl_dataset_cfg.get("_target_"),
176        )
177        self.mtl_dataset = hydra.utils.instantiate(mtl_dataset_cfg)
178        pylogger.debug(
179            "MTL dataset <%s> (clarena.mtl_datasets.MTLDataset) instantiated!",
180            mtl_dataset_cfg.get("_target_"),
181        )

Instantiate the MTL dataset object from mtl_dataset_cfg.

def instantiate_backbone(self, backbone_cfg: omegaconf.dictconfig.DictConfig) -> None:
183    def instantiate_backbone(self, backbone_cfg: DictConfig) -> None:
184        r"""Instantiate the MTL backbone network object from `backbone_cfg`."""
185        pylogger.debug(
186            "Instantiating backbone network <%s> (clarena.backbones.Backbone)...",
187            backbone_cfg.get("_target_"),
188        )
189        self.backbone = hydra.utils.instantiate(backbone_cfg)
190        pylogger.debug(
191            "Backbone network <%s> (clarena.backbones.Backbone) instantiated!",
192            backbone_cfg.get("_target_"),
193        )

Instantiate the MTL backbone network object from backbone_cfg.

def instantiate_heads(self, input_dim: int) -> None:
195    def instantiate_heads(
196        self,
197        input_dim: int,
198    ) -> None:
199        r"""Instantiate the MTL output heads object.
200
201        **Args:**
202        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
203        """
204        pylogger.debug(
205            "Instantiating MTL heads...",
206        )
207        self.heads = HeadsMTL(input_dim=input_dim)
208        pylogger.debug("MTL heads instantiated! ")

Instantiate the MTL output heads object.

Args:

  • input_dim (int): the input dimension of the heads. Must be equal to the output_dim of the connected backbone.
def instantiate_mtl_algorithm( self, mtl_algorithm_cfg: omegaconf.dictconfig.DictConfig, backbone: clarena.backbones.Backbone, heads: clarena.heads.HeadsMTL, non_algorithmic_hparams: dict[str, typing.Any]) -> None:
210    def instantiate_mtl_algorithm(
211        self,
212        mtl_algorithm_cfg: DictConfig,
213        backbone: Backbone,
214        heads: HeadsMTL,
215        non_algorithmic_hparams: dict[str, Any],
216    ) -> None:
217        r"""Instantiate the mtl_algorithm object from `mtl_algorithm_cfg`, `backbone`, `heads` and `non_algorithmic_hparams`."""
218        pylogger.debug(
219            "MTL algorithm is set as <%s>. Instantiating <%s> (clarena.mtl_algorithms.MTLAlgorithm)...",
220            mtl_algorithm_cfg.get("_target_"),
221            mtl_algorithm_cfg.get("_target_"),
222        )
223        self.model = hydra.utils.instantiate(
224            mtl_algorithm_cfg,
225            backbone=backbone,
226            heads=heads,
227            non_algorithmic_hparams=non_algorithmic_hparams,
228        )
229        pylogger.debug(
230            "<%s> (clarena.mtl_algorithms.MTLAlgorithm) instantiated!",
231            mtl_algorithm_cfg.get("_target_"),
232        )

Instantiate the mtl_algorithm object from mtl_algorithm_cfg, backbone, heads and non_algorithmic_hparams.

def instantiate_optimizer(self, optimizer_cfg: omegaconf.dictconfig.DictConfig) -> None:
234    def instantiate_optimizer(
235        self,
236        optimizer_cfg: DictConfig,
237    ) -> None:
238        r"""Instantiate the optimizer object from `optimizer_cfg`."""
239
240        # partially instantiate optimizer as the 'params' argument is from Lightning Modules cannot be passed for now.
241        pylogger.debug(
242            "Partially instantiating optimizer <%s> (torch.optim.Optimizer)...",
243            optimizer_cfg.get("_target_"),
244        )
245        self.optimizer = hydra.utils.instantiate(optimizer_cfg)
246        pylogger.debug(
247            "Optimizer <%s> (torch.optim.Optimizer) partially instantiated!",
248            optimizer_cfg.get("_target_"),
249        )

Instantiate the optimizer object from optimizer_cfg.

def instantiate_lr_scheduler(self, lr_scheduler_cfg: omegaconf.dictconfig.DictConfig) -> None:
251    def instantiate_lr_scheduler(
252        self,
253        lr_scheduler_cfg: DictConfig,
254    ) -> None:
255        r"""Instantiate the learning rate scheduler object from `lr_scheduler_cfg`."""
256
257        # partially instantiate learning rate scheduler as the 'params' argument is from Lightning Modules cannot be passed for now.
258        pylogger.debug(
259            "Partially instantiating learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) ...",
260            lr_scheduler_cfg.get("_target_"),
261        )
262        self.lr_scheduler = hydra.utils.instantiate(lr_scheduler_cfg)
263        pylogger.debug(
264            "Learning rate scheduler <%s> (torch.optim.lr_scheduler.LRScheduler) partially instantiated!",
265            lr_scheduler_cfg.get("_target_"),
266        )

Instantiate the learning rate scheduler object from lr_scheduler_cfg.

def instantiate_lightning_loggers(self, lightning_loggers_cfg: omegaconf.dictconfig.DictConfig) -> None:
268    def instantiate_lightning_loggers(self, lightning_loggers_cfg: DictConfig) -> None:
269        r"""Instantiate the list of lightning loggers objects from `lightning_loggers_cfg`."""
270        pylogger.debug("Instantiating Lightning loggers (lightning.Logger)...")
271        self.lightning_loggers = [
272            hydra.utils.instantiate(lightning_logger)
273            for lightning_logger in lightning_loggers_cfg.values()
274        ]
275        pylogger.debug("Lightning loggers (lightning.Logger) instantiated!")

Instantiate the list of lightning loggers objects from lightning_loggers_cfg.

def instantiate_callbacks( self, metrics_cfg: omegaconf.dictconfig.DictConfig, callbacks_cfg: omegaconf.dictconfig.DictConfig) -> None:
277    def instantiate_callbacks(
278        self, metrics_cfg: DictConfig, callbacks_cfg: DictConfig
279    ) -> None:
280        r"""Instantiate the list of callbacks objects from `metrics_cfg` and `callbacks_cfg`. Note that `metrics_cfg` is a list of metric callbacks and `callbacks_cfg` is a list of callbacks other the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks."""
281        pylogger.debug("Instantiating callbacks (lightning.Callback)...")
282
283        # instantiate metric callbacks
284        metric_callbacks = [
285            hydra.utils.instantiate(callback) for callback in metrics_cfg
286        ]
287
288        # instantiate other callbacks
289        other_callbacks = [
290            hydra.utils.instantiate(callback) for callback in callbacks_cfg
291        ]
292
293        # add metric callbacks to the list of callbacks
294        self.callbacks = metric_callbacks + other_callbacks
295        pylogger.debug("Callbacks (lightning.Callback) instantiated!")

Instantiate the list of callbacks objects from metrics_cfg and callbacks_cfg. Note that metrics_cfg is a list of metric callbacks and callbacks_cfg is a list of callbacks other than the metric callbacks. The instantiated callbacks contain both metric callbacks and other callbacks.

def instantiate_trainer( self, trainer_cfg: omegaconf.dictconfig.DictConfig, lightning_loggers: list[lightning.pytorch.loggers.logger.Logger], callbacks: list[lightning.pytorch.callbacks.callback.Callback]) -> None:
297    def instantiate_trainer(
298        self,
299        trainer_cfg: DictConfig,
300        lightning_loggers: list[Logger],
301        callbacks: list[Callback],
302    ) -> None:
303        r"""Instantiate the trainer object from `trainer_cfg`, `lightning_loggers`, and `callbacks`."""
304
305        pylogger.debug("Instantiating trainer (lightning.Trainer)...")
306        self.trainer = hydra.utils.instantiate(
307            trainer_cfg, logger=lightning_loggers, callbacks=callbacks
308        )
309        pylogger.debug("Trainer (lightning.Trainer) instantiated!")

Instantiate the trainer object from trainer_cfg, lightning_loggers, and callbacks.

def set_global_seed(self, global_seed: int) -> None:
311    def set_global_seed(self, global_seed: int) -> None:
312        r"""Set the `global_seed` for the entire experiment."""
313        L.seed_everything(self.global_seed, workers=True)
314        pylogger.debug("Global seed is set as %d.", global_seed)

Set the global_seed for the entire experiment.

def run(self) -> None:
316    def run(self) -> None:
317        r"""The main method to run the multi-task learning experiment."""
318        self.set_global_seed(self.global_seed)
319
320        self.instantiate_mtl_dataset(mtl_dataset_cfg=self.cfg.mtl_dataset)
321        self.instantiate_backbone(backbone_cfg=self.cfg.backbone)
322        self.instantiate_heads(input_dim=self.cfg.backbone.output_dim)
323        self.instantiate_mtl_algorithm(
324            mtl_algorithm_cfg=self.cfg.mtl_algorithm,
325            backbone=self.backbone,
326            heads=self.heads,
327            non_algorithmic_hparams=select_hyperparameters_from_config(
328                cfg=self.cfg, type=self.cfg.pipeline
329            ),
330        )  # mtl_algorithm should be instantiated after backbone and heads
331        self.instantiate_optimizer(optimizer_cfg=self.cfg.optimizer)
332        self.instantiate_lr_scheduler(lr_scheduler_cfg=self.cfg.lr_scheduler)
333        self.instantiate_lightning_loggers(
334            lightning_loggers_cfg=self.cfg.lightning_loggers
335        )
336        self.instantiate_callbacks(
337            metrics_cfg=self.cfg.metrics, callbacks_cfg=self.cfg.callbacks
338        )
339        self.instantiate_trainer(
340            trainer_cfg=self.cfg.trainer,
341            lightning_loggers=self.lightning_loggers,
342            callbacks=self.callbacks,
343        )  # trainer should be instantiated after lightning loggers and callbacks
344
345        # setup tasks for dataset and model
346        self.mtl_dataset.setup_tasks_expr(
347            train_tasks=self.train_tasks, eval_tasks=self.eval_tasks
348        )
349        self.model.setup_tasks(
350            task_ids=self.train_tasks,
351            num_classes={
352                task_id: len(self.mtl_dataset.get_mtl_class_map(task_id))
353                for task_id in self.train_tasks
354            },
355            optimizer=self.optimizer,
356            lr_scheduler=self.lr_scheduler,
357        )
358
359        # train and validate the model
360        self.trainer.fit(
361            model=self.model,
362            datamodule=self.mtl_dataset,
363        )
364
365        # evaluation after training and validation
366        self.trainer.test(
367            model=self.model,
368            datamodule=self.mtl_dataset,
369        )

The main method to run the multi-task learning experiment.