clarena.utils.cfg

The submodule in `utils` with tools related to configuration handling.

  1r"""The submodule in `utils` with tools related to configs."""
  2
  3__all__ = [
  4    "preprocess_config",
  5    "cfg_to_tree",
  6    "save_tree_to_file",
  7]
  8
  9import logging
 10import os
 11from copy import deepcopy
 12from typing import Any
 13
 14import rich
 15from omegaconf import DictConfig, ListConfig, OmegaConf
 16from rich.syntax import Syntax
 17from rich.tree import Tree
 18
 19# always get logger for built-in logging in each module
 20pylogger = logging.getLogger(__name__)
 21
 22
def preprocess_config(cfg: DictConfig, type: str) -> DictConfig:
    r"""Preprocess the configuration before constructing experiment, which includes:

    1. Construct the config for pipelines that borrow from other config.
    2. Convert the `DictConfig` to a Rich `Tree`, print the Rich `Tree` and save the Rich `Tree` to a file.

    **Args:**
    - **cfg** (`DictConfig`): the config dict to preprocess.
    - **type** (`str`): the type of the pipeline; one of:
        1. 'CL_MAIN_EXPR': continual learning main experiment.
        2. 'CL_MAIN_EVAL': continual learning main evaluation.
        3. 'CL_REF_JOINT_EXPR': reference joint learning experiment (continual learning).
        4. 'CL_REF_INDEPENDENT_EXPR': reference independent learning experiment (continual learning).
        5. 'CL_REF_RANDOM_EXPR': reference random learning experiment (continual learning).
        6. 'CL_FULL_EVAL': continual learning full evaluation.
        7. 'CL_FULL_EVAL_ATTACHED': continual learning full evaluation (attached to continual learning full experiment).
        8. 'CUL_MAIN_EXPR': continual unlearning main experiment.
        9. 'CUL_MAIN_EVAL': continual unlearning main evaluation.
        10. 'CUL_REF_RETRAIN_EXPR': reference retrain learning experiment (continual unlearning).
        11. 'CUL_REF_ORIGINAL_EXPR': reference original learning experiment (continual unlearning).
        12. 'CUL_FULL_EVAL': continual unlearning full evaluation.
        13. 'CUL_FULL_EVAL_ATTACHED': continual unlearning full evaluation (attached to continual unlearning full experiment).
        14. 'MTL_EXPR': multi-task learning experiment.
        15. 'MTL_EVAL': multi-task learning evaluation.
        16. 'STL_EXPR': single-task learning experiment.
        17. 'STL_EVAL': single-task learning evaluation.

    **Returns:**
    - **cfg** (`DictConfig`): the preprocessed config dict.
    """
    # work on a copy so the caller's config object is never mutated
    cfg = deepcopy(cfg)

    OmegaConf.set_struct(cfg, False)  # enable editing

    # these pipeline types use their config exactly as given
    if type in [
        "CL_MAIN_EXPR",
        "CL_MAIN_EVAL",
        "CL_FULL_EVAL",
        "CUL_MAIN_EXPR",
        "CUL_MAIN_EVAL",
        "CUL_FULL_EVAL",
        "MTL_EXPR",
        "MTL_EVAL",
        "STL_EXPR",
        "STL_EVAL",
    ]:
        pass  # keep the config unchanged

    if type == "CL_REF_JOINT_EXPR":
        # construct the config for reference joint learning experiment (continual learning) from the config for continual learning main experiment

        # set the output directory under the CL main experiment output directory
        cfg.output_dir = os.path.join(cfg.output_dir, "refjoint")

        # set the CL paradigm to None, since this is a joint learning experiment
        del cfg.cl_paradigm

        # set the eval tasks to the train tasks
        cfg.eval_tasks = cfg.train_tasks

        # set the eval after tasks to None, since this is a joint learning experiment
        del cfg.eval_after_tasks

        cl_dataset_cfg = cfg.cl_dataset

        # add the mtl_dataset to the config, wrapping the CL dataset;
        # batch_size may be per-task (a list) — use the first entry in that case
        cfg.mtl_dataset = {
            "_target_": "clarena.mtl_datasets.MTLDatasetFromCL",
            "cl_dataset": cl_dataset_cfg,
            "sampling_strategy": "mixed",
            "batch_size": (
                cl_dataset_cfg.batch_size
                if isinstance(cl_dataset_cfg.batch_size, int)
                else cl_dataset_cfg.batch_size[0]
            ),
        }

        # delete the cl_dataset, since this is a joint learning experiment
        del cfg.cl_dataset

        # delete the cl_algorithm, since this is a joint learning experiment
        del cfg.cl_algorithm

        # add the mtl_algorithm to the config
        cfg.mtl_algorithm = {"_target_": "clarena.mtl_algorithms.JointLearning"}

        # revise metrics: swap CL metrics for their MTL counterparts, keep the rest
        new_metrics = []
        for metric in cfg.metrics:
            target = metric.get("_target_")
            if target == "clarena.metrics.CLAccuracy":
                new_metrics.append(
                    {
                        "_target_": "clarena.metrics.MTLAccuracy",
                        "save_dir": "${output_dir}/results/",
                        "test_acc_csv_name": "acc.csv",
                        "test_acc_plot_name": "acc.png",
                    }
                )
            elif target == "clarena.metrics.CLLoss":
                new_metrics.append(
                    {
                        "_target_": "clarena.metrics.MTLLoss",
                        "save_dir": "${output_dir}/results/",
                        "test_loss_cls_csv_name": "loss_cls.csv",
                        "test_loss_cls_plot_name": "loss_cls.png",
                    }
                )
            else:
                new_metrics.append(metric)
        cfg.metrics = new_metrics

        # revise callbacks: CL pylogger becomes MTL pylogger (edited in place)
        for callback in cfg.callbacks:
            if callback.get("_target_") == "clarena.callbacks.CLPylogger":
                callback["_target_"] = "clarena.callbacks.MTLPylogger"

    elif type == "CL_REF_INDEPENDENT_EXPR":
        # construct the config for reference independent learning experiment (continual learning) from the config for continual learning main experiment

        # set the output directory under the CL main experiment output directory
        cfg.output_dir = os.path.join(cfg.output_dir, "refindependent")

        # change the cl_algorithm in the config
        cfg.cl_algorithm = {"_target_": "clarena.cl_algorithms.Independent"}

    elif type == "CL_REF_RANDOM_EXPR":
        # construct the config for reference random learning experiment (continual learning) from the config for continual learning main experiment

        # set the output directory under the CL main experiment output directory
        cfg.output_dir = os.path.join(cfg.output_dir, "refrandom")

        # change the cl_algorithm in the config
        cfg.cl_algorithm = {"_target_": "clarena.cl_algorithms.Random"}

    elif type == "CL_FULL_EVAL_ATTACHED":
        # construct the config for continual learning full evaluation from the config for continual learning main experiment

        eval_tasks = cfg.train_tasks

        main_acc_csv_path = os.path.join(cfg.output_dir, "results", "acc.csv")

        # reference CSV paths: explicit paths in the config win; otherwise fall
        # back to the conventional location under the main output directory
        if cfg.get("refjoint_acc_csv_path"):
            refjoint_acc_csv_path = cfg.refjoint_acc_csv_path
        else:
            refjoint_acc_csv_path = os.path.join(
                cfg.output_dir, "refjoint", "results", "acc.csv"
            )

        if cfg.get("refindependent_acc_csv_path"):
            refindependent_acc_csv_path = cfg.refindependent_acc_csv_path
        else:
            refindependent_acc_csv_path = os.path.join(
                cfg.output_dir, "refindependent", "results", "acc.csv"
            )

        if cfg.get("refrandom_acc_csv_path"):
            refrandom_acc_csv_path = cfg.refrandom_acc_csv_path
        else:
            refrandom_acc_csv_path = os.path.join(
                cfg.output_dir, "refrandom", "results", "acc.csv"
            )

        # output locations for backward transfer (bwt), forward transfer (fwt)
        # and forgetting rate (fr) results
        output_dir = cfg.output_dir
        bwt_save_dir = os.path.join(output_dir, "results")
        bwt_csv_name = "bwt.csv"
        bwt_plot_name = "bwt.png"
        fwt_save_dir = os.path.join(output_dir, "results")
        fwt_csv_name = "fwt.csv"
        fwt_plot_name = "fwt.png"
        fr_save_dir = os.path.join(output_dir, "results")
        fr_csv_name = "fr.csv"
        misc_cfg = cfg.misc

        # replace the config entirely with a freshly built CL_FULL_EVAL config
        cfg = OmegaConf.create(
            {
                "pipeline": "CL_FULL_EVAL",
                "eval_tasks": eval_tasks,
                "main_acc_csv_path": main_acc_csv_path,
                "refjoint_acc_csv_path": refjoint_acc_csv_path,
                "refindependent_acc_csv_path": refindependent_acc_csv_path,
                "refrandom_acc_csv_path": refrandom_acc_csv_path,
                "output_dir": output_dir,
                "bwt_save_dir": bwt_save_dir,
                "bwt_csv_name": bwt_csv_name,
                "bwt_plot_name": bwt_plot_name,
                "fwt_save_dir": fwt_save_dir,
                "fwt_csv_name": fwt_csv_name,
                "fwt_plot_name": fwt_plot_name,
                "fr_save_dir": fr_save_dir,
                "fr_csv_name": fr_csv_name,
                "misc": misc_cfg,
            }
        )

    elif type == "CUL_REF_RETRAIN_EXPR":
        # construct the config for reference retrain learning experiment (continual unlearning) from the config for continual unlearning main experiment

        # set the output directory under the main experiment output directory
        cfg.output_dir = os.path.join(cfg.output_dir, "refretrain")

        # skip the unlearning tasks specified in unlearning_requests
        # NOTE: train_tasks may be either an explicit list or an int N meaning tasks 1..N
        train_tasks = (
            cfg.train_tasks
            if isinstance(cfg.train_tasks, ListConfig)
            else ListConfig(list(range(1, cfg.train_tasks + 1)))
        )
        for unlearning_task_ids in cfg.unlearning_requests.values():
            for unlearning_task_id in unlearning_task_ids:
                if unlearning_task_id in train_tasks:
                    train_tasks.remove(unlearning_task_id)
        cfg.train_tasks = train_tasks

        # delete the unlearning configs, since this is a continual learning experiment
        del cfg.cul_algorithm, cfg.unlearning_requests

        # revise callbacks: CUL pylogger becomes CL pylogger (edited in place)
        for callback in cfg.callbacks:
            if callback.get("_target_") == "clarena.callbacks.CULPylogger":
                callback["_target_"] = "clarena.callbacks.CLPylogger"

    elif type == "CUL_REF_ORIGINAL_EXPR":
        # construct the config for reference original learning experiment (continual unlearning) from the config for continual unlearning main experiment

        # set the output directory under the main experiment output directory
        cfg.output_dir = os.path.join(cfg.output_dir, "reforiginal")

        # just do the CLExperiment using the unlearning experiment config will automatically ignore the unlearning process, which is exactly the full experiment

        # delete the unlearning configs, since this is a continual learning experiment
        del cfg.cul_algorithm, cfg.unlearning_requests

        # revise callbacks: CUL pylogger becomes CL pylogger (edited in place)
        for callback in cfg.callbacks:
            if callback.get("_target_") == "clarena.callbacks.CULPylogger":
                callback["_target_"] = "clarena.callbacks.CLPylogger"

    elif type == "CUL_FULL_EVAL_ATTACHED":
        # construct the config for continual unlearning full evaluation from the config for continual unlearning main experiment

        # distribution distance (dd) is always evaluated; accuracy gain (ag)
        # needs the reforiginal reference run, so it is skipped when
        # if_run_reforiginal is explicitly False (absent defaults to running it)
        dd_eval_tasks = cfg.train_tasks
        if_run_reforiginal = cfg.get("if_run_reforiginal") is not False
        ag_eval_tasks = cfg.train_tasks if if_run_reforiginal else []
        global_seed = cfg.global_seed

        main_model_path = os.path.join(cfg.output_dir, "saved_models", "cl_model.pth")

        # reference model paths: explicit paths in the config win; otherwise fall
        # back to the conventional location under the main output directory
        if cfg.get("refretrain_model_path"):
            refretrain_model_path = cfg.refretrain_model_path
        else:
            refretrain_model_path = os.path.join(
                cfg.output_dir, "refretrain", "saved_models", "cl_model.pth"
            )

        if not if_run_reforiginal:
            reforiginal_model_path = None
        elif cfg.get("reforiginal_model_path"):
            reforiginal_model_path = cfg.reforiginal_model_path
        else:
            reforiginal_model_path = os.path.join(
                cfg.output_dir, "reforiginal", "saved_models", "cl_model.pth"
            )

        cl_paradigm = cfg.cl_paradigm
        cl_dataset = cfg.cl_dataset
        trainer = cfg.trainer
        # metrics: distribution distance always; accuracy gain only with reforiginal
        metrics = [
            {
                "_target_": "clarena.metrics.CULDistributionDistance",
                "save_dir": "${output_dir}/results/",
                "distribution_distance_type": "linear_cka",
                "distribution_distance_csv_name": "dd.csv",
                "distribution_distance_plot_name": "dd.png",
            }
        ]
        if if_run_reforiginal:
            metrics.append(
                {
                    "_target_": "clarena.metrics.CULAccuracyGain",
                    "save_dir": "${output_dir}/results/",
                    "accuracy_gain_csv_name": "ag.csv",
                    "accuracy_gain_plot_name": "ag.png",
                }
            )
        metrics = OmegaConf.create(metrics)
        callbacks = OmegaConf.create(
            [
                {
                    "_target_": "lightning.pytorch.callbacks.RichProgressBar",
                },
            ]
        )
        misc = cfg.misc
        output_dir = cfg.output_dir

        # replace the config entirely with a freshly built CUL_FULL_EVAL config
        cfg = OmegaConf.create(
            {
                "pipeline": "CUL_FULL_EVAL",
                "dd_eval_tasks": dd_eval_tasks,
                "ag_eval_tasks": ag_eval_tasks,
                "global_seed": global_seed,
                "main_model_path": main_model_path,
                "refretrain_model_path": refretrain_model_path,
                "reforiginal_model_path": reforiginal_model_path,
                "if_run_reforiginal": if_run_reforiginal,
                "cl_paradigm": cl_paradigm,
                "cl_dataset": cl_dataset,
                "trainer": trainer,
                "metrics": metrics,
                "callbacks": callbacks,
                "misc": misc,
                "output_dir": output_dir,
            }
        )

    # standalone CUL_FULL_EVAL with reforiginal explicitly disabled: strip the
    # accuracy-gain pieces so the evaluation does not look for the missing reference
    if type == "CUL_FULL_EVAL" and cfg.get("if_run_reforiginal") is False:
        cfg.reforiginal_model_path = None
        if cfg.get("ag_eval_tasks") is not None:
            cfg.ag_eval_tasks = OmegaConf.create([])
        if cfg.get("metrics"):
            cfg.metrics = OmegaConf.create(
                [
                    metric
                    for metric in cfg.metrics
                    if metric.get("_target_") != "clarena.metrics.CULAccuracyGain"
                ]
            )

    # For full-experiment attached evals, skip writing config_tree.log (last eval step).
    if type in ["CL_FULL_EVAL_ATTACHED", "CUL_FULL_EVAL_ATTACHED"]:
        if cfg.get("misc") and cfg.misc.get("config_tree"):
            cfg.misc.config_tree.save = False

    OmegaConf.set_struct(cfg, True)

    if cfg.get("misc"):
        if cfg.misc.get("config_tree"):
            # parse config used for config tree
            config_tree_cfg = cfg.misc.config_tree
            if_print = (
                config_tree_cfg.print
            )  # to avoid using `print` as a variable name, which is supposed to be a built-in function
            save = config_tree_cfg.save
            save_path = config_tree_cfg.save_path

            # convert config to tree
            tree = cfg_to_tree(cfg, config_tree_cfg)

            if if_print:
                rich.print(tree)  # print the tree
            if save:
                save_tree_to_file(tree, save_path)  # save the tree to file

    return cfg
377
378
def select_hyperparameters_from_config(cfg: DictConfig, type: str) -> dict[str, Any]:
    r"""Select hyperparameters from the configuration based on the experiment type.

    **Args:**
    - **cfg** (`DictConfig`): the config dict to select hyperparameters from.
    - **type** (`str`): the type of the experiment; one of:
        1. 'CL_MAIN_EXPR': continual learning main experiment.
        2. 'CL_REF_JOINT_EXPR': reference joint learning experiment (continual learning).
        3. 'CL_REF_INDEPENDENT_EXPR': reference independent learning experiment (continual learning).
        4. 'CL_REF_RANDOM_EXPR': reference random learning experiment (continual learning).
        5. 'CL_FULL_EVAL_ATTACHED': continual learning full evaluation (attached to continual learning full experiment).
        6. 'CUL_MAIN_EXPR': continual unlearning main experiment.
        7. 'CUL_REF_RETRAIN_EXPR': reference retrain learning experiment (continual unlearning).
        8. 'CUL_REF_ORIGINAL_EXPR': reference original learning experiment (continual unlearning).
        9. 'MTL_EXPR': multi-task learning experiment.
        10. 'STL_EXPR': single-task learning experiment.

    **Returns:**
    - **hyperparameters** (`dict[str, Any]`): the selected hyperparameters.
    """
    hparams: dict[str, Any] = {}

    # take the dataset batch size from whichever dataset type is present
    if cfg.get("cl_dataset"):
        hparams["batch_size"] = cfg.cl_dataset.batch_size
    elif cfg.get("mtl_dataset"):
        hparams["batch_size"] = cfg.mtl_dataset.batch_size
    elif cfg.get("stl_dataset"):
        hparams["batch_size"] = cfg.stl_dataset.batch_size

    # take backbone hyperparameters (guarded in case the config has no backbone)
    if cfg.get("backbone"):
        hparams["backbone"] = cfg.backbone.get("_target_")
        for k, v in cfg.backbone.items():
            if k != "_target_":
                hparams[f"backbone.{k}"] = v

    # take optimizer hyperparameters.
    # A uniform optimizer is a single `DictConfig`; a per-task optimizer is a
    # `ListConfig`, which has no string keys or `.items()`. The previous check
    # was inverted (`isinstance(..., ListConfig)`) and would crash on entry;
    # only the uniform (DictConfig) case is recorded, or it will be too messy.
    if cfg.get("optimizer") is not None and isinstance(cfg.optimizer, DictConfig):
        hparams["optimizer"] = cfg.optimizer.get("_target_")
        for k, v in cfg.optimizer.items():
            if k != "_target_" and k != "_partial_":
                hparams[f"optimizer.{k}"] = v

    # take lr_scheduler hyperparameters (same uniform-only rule as the optimizer)
    if cfg.get("lr_scheduler") is not None and isinstance(cfg.lr_scheduler, DictConfig):
        hparams["lr_scheduler"] = cfg.lr_scheduler.get("_target_")
        for k, v in cfg.lr_scheduler.items():
            if k != "_target_" and k != "_partial_":
                hparams[f"lr_scheduler.{k}"] = v

    return hparams
434
435
def cfg_to_tree(cfg: DictConfig, config_tree_cfg: DictConfig) -> Tree:
    r"""Convert the configuration to a Rich `Tree`.

    **Args:**
    - **cfg** (`DictConfig`): the target config dict to be converted.
    - **config_tree_cfg** (`DictConfig`): the configuration for conversion of config tree.

    **Returns:**
    - **tree** (`Tree`): the Rich `Tree`.
    """
    # configs for tree (styles for the labels and the guide lines)
    style = config_tree_cfg.style
    guide_style = config_tree_cfg.guide_style

    # initialize the tree (use the directly imported `Tree` for consistency)
    tree = Tree(label="CONFIG", style=style, guide_style=guide_style)

    # generate one branch per top-level config field (no intermediate queue needed)
    for field in cfg:
        branch = tree.add(field, style=style, guide_style=guide_style)
        field_cfg = cfg[field]
        # render container fields as YAML — including ListConfig fields such as
        # `metrics`/`callbacks`, which previously fell through to `str()` and
        # printed as an object repr instead of readable YAML; scalars stay `str()`
        branch_content = (
            OmegaConf.to_yaml(field_cfg, resolve=True)
            if isinstance(field_cfg, (DictConfig, ListConfig))
            else str(field_cfg)
        )
        branch.add(Syntax(branch_content, "yaml"))

    return tree
471
472
def save_tree_to_file(tree: Tree, save_path: str) -> None:
    r"""Save Rich `Tree` to a file.

    **Args:**
    - **tree** (`Tree`): the Rich `Tree` to save.
    - **save_path** (`str`): the path to save the tree.
    """
    # ensure the parent directory exists; guard against an empty dirname
    # (a bare filename), where `os.makedirs("")` would raise. The previous
    # `os.path.exists(save_path)` check tested the file, not the directory.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    with open(save_path, "w", encoding="utf-8") as file:
        rich.print(tree, file=file)
def preprocess_config(cfg: omegaconf.dictconfig.DictConfig, type: str) -> None:
 24def preprocess_config(cfg: DictConfig, type: str) -> None:
 25    r"""Preprocess the configuration before constructing experiment, which include:
 26
 27    1. Construct the config for pipelines that borrow from other config.
 28    2. Convert the `DictConfig` to a Rich `Tree`, print the Rich `Tree` and save the Rich `Tree` to a file.
 29
 30    **Args:**
 31    - **cfg** (`DictConfig`): the config dict to preprocess.
 32    - **type** (`str`): the type of the pipeline; one of:
 33        1. 'CL_MAIN_EXPR': continual learning main experiment.
 34        2. 'CL_MAIN_EVAL': continual learning main evaluation.
 35        3. 'CL_REF_JOINT_EXPR': reference joint learning experiment (continual learning).
 36        4. 'CL_REF_INDEPENDENT_EXPR': reference independent learning experiment (continual learning).
 37        5. 'CL_REF_RANDOM_EXPR': reference random learning experiment (continual learning).
 38        6. 'CL_FULL_EVAL': continual learning full evaluation.
 39        7. 'CL_FULL_EVAL_ATTACHED': continual unlearning full evaluation (attached to continual learning full experiment).
 40        8. 'CUL_MAIN_EXPR': continual unlearning main experiment.
 41        9. 'CUL_MAIN_EVAL': continual unlearning main evaluation.
 42        10. 'CUL_REF_RETRAIN_EXPR': reference retrain learning experiment (continual unlearning).
 43        11. 'CUL_REF_ORIGINAL_EXPR': reference original learning experiment (contin
 44        12, 'CUL_FULL_EVAL': continual unlearning full evaluation.
 45        13. 'CUL_FULL_EVAL_ATTACHED': continual unlearning full evaluation (attached to continual unlearning full experiment).
 46        14. 'MTL_EXPR': multi-task learning experiment.
 47        15. 'MTL_EVAL': multi-task learning evaluation.
 48        16. 'STL_EXPR': single-task learning experiment.
 49        17. 'STL_EVAL': single-task learning evaluation.
 50
 51    **Returns:**
 52    - **cfg** (`DictConfig`): the preprocessed config dict.
 53    """
 54    cfg = deepcopy(cfg)
 55
 56    OmegaConf.set_struct(cfg, False)  # enable editing
 57
 58    if type in [
 59        "CL_MAIN_EXPR",
 60        "CL_MAIN_EVAL",
 61        "CL_FULL_EVAL",
 62        "CUL_MAIN_EXPR",
 63        "CUL_MAIN_EVAL",
 64        "CUL_FULL_EVAL",
 65        "MTL_EXPR",
 66        "MTL_EVAL",
 67        "STL_EXPR",
 68        "STL_EVAL",
 69    ]:
 70        pass  # keep the config unchanged
 71
 72    if type == "CL_REF_JOINT_EXPR":
 73        # construct the config for reference joint learning experiment (continual learning) from the config for continual learning main experiment
 74
 75        # set the output directory under the CL main experiment output directory
 76        cfg.output_dir = os.path.join(cfg.output_dir, "refjoint")
 77
 78        # set the CL paradigm to None, since this is a joint learning experiment
 79        del cfg.cl_paradigm
 80
 81        # set the eval tasks to the train tasks
 82        cfg.eval_tasks = cfg.train_tasks
 83
 84        # set the eval after tasks to None, since this is a joint learning experiment
 85        del cfg.eval_after_tasks
 86
 87        cl_dataset_cfg = cfg.cl_dataset
 88
 89        # add the mtl_dataset to the config
 90        cfg.mtl_dataset = {
 91            "_target_": "clarena.mtl_datasets.MTLDatasetFromCL",
 92            "cl_dataset": cl_dataset_cfg,
 93            "sampling_strategy": "mixed",
 94            "batch_size": (
 95                cl_dataset_cfg.batch_size
 96                if isinstance(cl_dataset_cfg.batch_size, int)
 97                else cl_dataset_cfg.batch_size[0]
 98            ),
 99        }
100
101        # delete the cl_dataset, since this is a joint learning experiment
102        del cfg.cl_dataset
103
104        # delete the cl_algorithm, since this is a joint learning experiment
105        del cfg.cl_algorithm
106
107        # add the mtl_algorithm to the config
108        cfg.mtl_algorithm = {"_target_": "clarena.mtl_algorithms.JointLearning"}
109
110        # revise metrics
111        new_metrics = []
112        for metric in cfg.metrics:
113            target = metric.get("_target_")
114            if target == "clarena.metrics.CLAccuracy":
115                new_metrics.append(
116                    {
117                        "_target_": "clarena.metrics.MTLAccuracy",
118                        "save_dir": "${output_dir}/results/",
119                        "test_acc_csv_name": "acc.csv",
120                        "test_acc_plot_name": "acc.png",
121                    }
122                )
123            elif target == "clarena.metrics.CLLoss":
124                new_metrics.append(
125                    {
126                        "_target_": "clarena.metrics.MTLLoss",
127                        "save_dir": "${output_dir}/results/",
128                        "test_loss_cls_csv_name": "loss_cls.csv",
129                        "test_loss_cls_plot_name": "loss_cls.png",
130                    }
131                )
132            else:
133                new_metrics.append(metric)
134        cfg.metrics = new_metrics
135
136        # revise callbacks
137        for callback in cfg.callbacks:
138            if callback.get("_target_") == "clarena.callbacks.CLPylogger":
139                callback["_target_"] = "clarena.callbacks.MTLPylogger"
140
141    elif type == "CL_REF_INDEPENDENT_EXPR":
142        # construct the config for reference independent learning experiment (continual learning) from the config for continual learning main experiment
143
144        # set the output directory under the CL main experiment output directory
145        cfg.output_dir = os.path.join(cfg.output_dir, "refindependent")
146
147        # change the cl_algorithm in the config
148        cfg.cl_algorithm = {"_target_": "clarena.cl_algorithms.Independent"}
149
150    elif type == "CL_REF_RANDOM_EXPR":
151        # construct the config for reference random learning experiment (continual learning) from the config for continual learning main experiment
152
153        # set the output directory under the CL main experiment output directory
154        cfg.output_dir = os.path.join(cfg.output_dir, "refrandom")
155
156        # change the cl_algorithm in the config
157        cfg.cl_algorithm = {"_target_": "clarena.cl_algorithms.Random"}
158
159    elif type == "CL_FULL_EVAL_ATTACHED":
160        # construct the config for continual learning full evaluation from the config for continual learning main experiment
161
162        eval_tasks = cfg.train_tasks
163
164        main_acc_csv_path = os.path.join(cfg.output_dir, "results", "acc.csv")
165
166        if cfg.get("refjoint_acc_csv_path"):
167            refjoint_acc_csv_path = cfg.refjoint_acc_csv_path
168        else:
169            refjoint_acc_csv_path = os.path.join(
170                cfg.output_dir, "refjoint", "results", "acc.csv"
171            )
172
173        if cfg.get("refindependent_acc_csv_path"):
174            refindependent_acc_csv_path = cfg.refindependent_acc_csv_path
175        else:
176            refindependent_acc_csv_path = os.path.join(
177                cfg.output_dir, "refindependent", "results", "acc.csv"
178            )
179
180        if cfg.get("refrandom_acc_csv_path"):
181            refrandom_acc_csv_path = cfg.refrandom_acc_csv_path
182        else:
183            refrandom_acc_csv_path = os.path.join(
184                cfg.output_dir, "refrandom", "results", "acc.csv"
185            )
186
187        output_dir = cfg.output_dir
188        bwt_save_dir = os.path.join(output_dir, "results")
189        bwt_csv_name = "bwt.csv"
190        bwt_plot_name = "bwt.png"
191        fwt_save_dir = os.path.join(output_dir, "results")
192        fwt_csv_name = "fwt.csv"
193        fwt_plot_name = "fwt.png"
194        fr_save_dir = os.path.join(output_dir, "results")
195        fr_csv_name = "fr.csv"
196        misc_cfg = cfg.misc
197
198        cfg = OmegaConf.create(
199            {
200                "pipeline": "CL_FULL_EVAL",
201                "eval_tasks": eval_tasks,
202                "main_acc_csv_path": main_acc_csv_path,
203                "refjoint_acc_csv_path": refjoint_acc_csv_path,
204                "refindependent_acc_csv_path": refindependent_acc_csv_path,
205                "refrandom_acc_csv_path": refrandom_acc_csv_path,
206                "output_dir": output_dir,
207                "bwt_save_dir": bwt_save_dir,
208                "bwt_csv_name": bwt_csv_name,
209                "bwt_plot_name": bwt_plot_name,
210                "fwt_save_dir": fwt_save_dir,
211                "fwt_csv_name": fwt_csv_name,
212                "fwt_plot_name": fwt_plot_name,
213                "fr_save_dir": fr_save_dir,
214                "fr_csv_name": fr_csv_name,
215                "misc": misc_cfg,
216            }
217        )
218
219    elif type == "CUL_REF_RETRAIN_EXPR":
220        # construct the config for reference retrain learning experiment (continual unlearning) from the config for continual unlearning main experiment
221
222        # set the output directory under the main experiment output directory
223        cfg.output_dir = os.path.join(cfg.output_dir, "refretrain")
224
225        # skip the unlearning tasks specified in unlearning_requests
226        train_tasks = (
227            cfg.train_tasks
228            if isinstance(cfg.train_tasks, ListConfig)
229            else ListConfig(list(range(1, cfg.train_tasks + 1)))
230        )
231        for unlearning_task_ids in cfg.unlearning_requests.values():
232            for unlearning_task_id in unlearning_task_ids:
233                if unlearning_task_id in train_tasks:
234                    train_tasks.remove(unlearning_task_id)
235        cfg.train_tasks = train_tasks
236
237        # delete the unlearning configs, since this is a continual learning experiment
238        del cfg.cul_algorithm, cfg.unlearning_requests
239
240        # revise callbacks
241        for callback in cfg.callbacks:
242            if callback.get("_target_") == "clarena.callbacks.CULPylogger":
243                callback["_target_"] = "clarena.callbacks.CLPylogger"
244
245    elif type == "CUL_REF_ORIGINAL_EXPR":
246        # construct the config for reference original learning experiment (continual unlearning) from the config for continual unlearning main experiment
247
248        # set the output directory under the main experiment output directory
249        cfg.output_dir = os.path.join(cfg.output_dir, "reforiginal")
250
251        # just do the CLExperiment using the unlearning experiment config will automatically ignore the unlearning process, which is exactly the full experiment
252
253        # delete the unlearning configs, since this is a continual learning experiment
254        del cfg.cul_algorithm, cfg.unlearning_requests
255
256        # revise callbacks
257        for callback in cfg.callbacks:
258            if callback.get("_target_") == "clarena.callbacks.CULPylogger":
259                callback["_target_"] = "clarena.callbacks.CLPylogger"
260
261    elif type == "CUL_FULL_EVAL_ATTACHED":
262        # construct the config for continual unlearning full evaluation from the config for continual unlearning main experiment
263
264        dd_eval_tasks = cfg.train_tasks
265        if_run_reforiginal = cfg.get("if_run_reforiginal") is not False
266        ag_eval_tasks = cfg.train_tasks if if_run_reforiginal else []
267        global_seed = cfg.global_seed
268
269        main_model_path = os.path.join(cfg.output_dir, "saved_models", "cl_model.pth")
270
271        if cfg.get("refretrain_model_path"):
272            refretrain_model_path = cfg.refretrain_model_path
273        else:
274            refretrain_model_path = os.path.join(
275                cfg.output_dir, "refretrain", "saved_models", "cl_model.pth"
276            )
277
278        if not if_run_reforiginal:
279            reforiginal_model_path = None
280        elif cfg.get("reforiginal_model_path"):
281            reforiginal_model_path = cfg.reforiginal_model_path
282        else:
283            reforiginal_model_path = os.path.join(
284                cfg.output_dir, "reforiginal", "saved_models", "cl_model.pth"
285            )
286
287        cl_paradigm = cfg.cl_paradigm
288        cl_dataset = cfg.cl_dataset
289        trainer = cfg.trainer
290        metrics = [
291            {
292                "_target_": "clarena.metrics.CULDistributionDistance",
293                "save_dir": "${output_dir}/results/",
294                "distribution_distance_type": "linear_cka",
295                "distribution_distance_csv_name": "dd.csv",
296                "distribution_distance_plot_name": "dd.png",
297            }
298        ]
299        if if_run_reforiginal:
300            metrics.append(
301                {
302                    "_target_": "clarena.metrics.CULAccuracyGain",
303                    "save_dir": "${output_dir}/results/",
304                    "accuracy_gain_csv_name": "ag.csv",
305                    "accuracy_gain_plot_name": "ag.png",
306                }
307            )
308        metrics = OmegaConf.create(metrics)
309        callbacks = OmegaConf.create(
310            [
311                {
312                    "_target_": "lightning.pytorch.callbacks.RichProgressBar",
313                },
314            ]
315        )
316        misc = cfg.misc
317        output_dir = cfg.output_dir
318
319        cfg = OmegaConf.create(
320            {
321                "pipeline": "CUL_FULL_EVAL",
322                "dd_eval_tasks": dd_eval_tasks,
323                "ag_eval_tasks": ag_eval_tasks,
324                "global_seed": global_seed,
325                "main_model_path": main_model_path,
326                "refretrain_model_path": refretrain_model_path,
327                "reforiginal_model_path": reforiginal_model_path,
328                "if_run_reforiginal": if_run_reforiginal,
329                "cl_paradigm": cl_paradigm,
330                "cl_dataset": cl_dataset,
331                "trainer": trainer,
332                "metrics": metrics,
333                "callbacks": callbacks,
334                "misc": misc,
335                "output_dir": output_dir,
336            }
337        )
338
339    if type == "CUL_FULL_EVAL" and cfg.get("if_run_reforiginal") is False:
340        cfg.reforiginal_model_path = None
341        if cfg.get("ag_eval_tasks") is not None:
342            cfg.ag_eval_tasks = OmegaConf.create([])
343        if cfg.get("metrics"):
344            cfg.metrics = OmegaConf.create(
345                [
346                    metric
347                    for metric in cfg.metrics
348                    if metric.get("_target_") != "clarena.metrics.CULAccuracyGain"
349                ]
350            )
351
352    # For full-experiment attached evals, skip writing config_tree.log (last eval step).
353    if type in ["CL_FULL_EVAL_ATTACHED", "CUL_FULL_EVAL_ATTACHED"]:
354        if cfg.get("misc") and cfg.misc.get("config_tree"):
355            cfg.misc.config_tree.save = False
356
357    OmegaConf.set_struct(cfg, True)
358
359    if cfg.get("misc"):
360        if cfg.misc.get("config_tree"):
361            # parse config used for config tree
362            config_tree_cfg = cfg.misc.config_tree
363            if_print = (
364                config_tree_cfg.print
365            )  # to avoid using `print` as a variable name, which is supposed to be a built-in function
366            save = config_tree_cfg.save
367            save_path = config_tree_cfg.save_path
368
369            # convert config to tree
370            tree = cfg_to_tree(cfg, config_tree_cfg)
371
372            if if_print:
373                rich.print(tree)  # print the tree
374            if save:
375                save_tree_to_file(tree, save_path)  # save the tree to file
376
377    return cfg

Preprocess the configuration before constructing the experiment, which includes:

  1. Construct the config for pipelines that borrow from other config.
  2. Convert the DictConfig to a Rich Tree, print the Rich Tree and save the Rich Tree to a file.

Args:

  • cfg (DictConfig): the config dict to preprocess.
  • type (str): the type of the pipeline; one of:
    1. 'CL_MAIN_EXPR': continual learning main experiment.
    2. 'CL_MAIN_EVAL': continual learning main evaluation.
    3. 'CL_REF_JOINT_EXPR': reference joint learning experiment (continual learning).
    4. 'CL_REF_INDEPENDENT_EXPR': reference independent learning experiment (continual learning).
    5. 'CL_REF_RANDOM_EXPR': reference random learning experiment (continual learning).
    6. 'CL_FULL_EVAL': continual learning full evaluation.
    7. 'CL_FULL_EVAL_ATTACHED': continual learning full evaluation (attached to continual learning full experiment).
    8. 'CUL_MAIN_EXPR': continual unlearning main experiment.
    9. 'CUL_MAIN_EVAL': continual unlearning main evaluation.
    10. 'CUL_REF_RETRAIN_EXPR': reference retrain learning experiment (continual unlearning).
    11. 'CUL_REF_ORIGINAL_EXPR': reference original learning experiment (continual unlearning).
    12. 'CUL_FULL_EVAL': continual unlearning full evaluation.
    13. 'CUL_FULL_EVAL_ATTACHED': continual unlearning full evaluation (attached to continual unlearning full experiment).
    14. 'MTL_EXPR': multi-task learning experiment.
    15. 'MTL_EVAL': multi-task learning evaluation.
    16. 'STL_EXPR': single-task learning experiment.
    17. 'STL_EVAL': single-task learning evaluation.

Returns:

  • cfg (DictConfig): the preprocessed config dict.
def cfg_to_tree( cfg: omegaconf.dictconfig.DictConfig, config_tree_cfg: omegaconf.dictconfig.DictConfig) -> rich.tree.Tree:
def cfg_to_tree(cfg: DictConfig, config_tree_cfg: DictConfig) -> Tree:
    r"""Convert the configuration to a Rich `Tree`.

    **Args:**
    - **cfg** (`DictConfig`): the target config dict to be converted.
    - **config_tree_cfg** (`DictConfig`): the configuration for conversion of config tree.

    **Returns:**
    - **tree** (`Tree`): the Rich `Tree`.
    """
    # styling options applied to the tree and every branch
    style = config_tree_cfg.style
    guide_style = config_tree_cfg.guide_style

    # initialize the tree (use the `Tree` class imported at module top)
    tree = Tree(label="CONFIG", style=style, guide_style=guide_style)

    # one branch per top-level field of the config; `DictConfig` fields are
    # rendered as resolved YAML, everything else falls back to `str()`
    for field in cfg:
        branch = tree.add(field, style=style, guide_style=guide_style)
        field_cfg = cfg[field]
        branch_content = (
            OmegaConf.to_yaml(field_cfg, resolve=True)
            if isinstance(field_cfg, DictConfig)
            else str(field_cfg)
        )
        branch.add(Syntax(branch_content, "yaml"))

    return tree

Convert the configuration to a Rich Tree.

Args:

  • cfg (DictConfig): the target config dict to be converted.
  • config_tree_cfg (DictConfig): the configuration for conversion of config tree.

Returns:

  • tree (Tree): the Rich Tree.
def save_tree_to_file(tree: dict, save_path: str) -> None:
def save_tree_to_file(tree: Tree, save_path: str) -> None:
    r"""Save Rich `Tree` to a file.

    **Args:**
    - **tree** (`Tree`): the Rich `Tree` to save.
    - **save_path** (`str`): the path to save the tree.
    """
    # create the parent directory if needed; `save_path` may be a bare
    # filename, in which case dirname is empty and nothing must be created
    # (os.makedirs("") would raise FileNotFoundError)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # UTF-8 so the Unicode guide glyphs of the Rich tree always encode
    with open(save_path, "w", encoding="utf-8") as file:
        rich.print(tree, file=file)

Save Rich Tree to a file.

Args:

  • tree (Tree): the Rich Tree to save.
  • save_path (str): the path to save the tree.