clarena.pipelines.cul_main_expr
The submodule in pipelines for the continual unlearning main experiment.
r"""
The submodule in `pipelines` for the continual unlearning main experiment.
"""

__all__ = ["CULMainExperiment"]

import logging

import hydra
from omegaconf import DictConfig, ListConfig

from clarena.cul_algorithms import CULAlgorithm
from clarena.pipelines import CLMainExperiment
from clarena.utils.cfg import select_hyperparameters_from_config

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CULMainExperiment(CLMainExperiment):
    r"""The base class for the continual unlearning main experiment."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the CUL experiment.
        """
        super().__init__(
            cfg
        )  # CUL main experiment inherits all configs from CL main experiment

        CULMainExperiment.sanity_check(self)

        self.cul_algorithm: CULAlgorithm
        r"""Continual unlearning algorithm object."""

        self.unlearning_requests: dict[int, list[int]] = cfg.unlearning_requests
        r"""The unlearning requests for each task in the experiment. Keys are IDs of the tasks that request unlearning after their learning, and values are the list of the previous tasks to be unlearned. Parsed from config and used in the tasks loop."""
        self.unlearned_task_ids: set[int] = set()
        r"""The set of task IDs that have been unlearned in the experiment. Updated in the tasks loop after each unlearning process."""

        # The raw `cfg.train_tasks` is not necessarily an int (it may be a list of
        # task IDs), so size the per-task dict from the parsed `self.train_tasks`
        # instead of computing `cfg.train_tasks + 1`. Every ID from 1 up to the
        # largest one gets an entry because the lookups in the methods below scan
        # `range(1, task_id + 1)`.
        self.unlearnable_ages: dict[int, int | None] | DictConfig = (
            cfg.unlearnable_age
            if isinstance(cfg.unlearnable_age, DictConfig)
            else {
                task_id: cfg.unlearnable_age
                for task_id in range(1, max(self.train_tasks) + 1)
            }
        )
        r"""The dict of task unlearnable ages. Keys are task IDs and values are the unlearnable age of the corresponding task. A task can no longer be unlearned once its age (i.e., the number of tasks learned after it) reaches this value. If `None`, the task is unlearnable at any time."""

    def sanity_check(self) -> None:
        r"""Check the sanity of the config dict `self.cfg`.

        **Raises:**
        - **KeyError**: when a required config field is missing, or (for every field except `unlearnable_age`) has a falsy value.
        - **ValueError**: when a task ID in `unlearning_requests` (either a requesting task or a task to be unlearned) is not within `train_tasks`.
        """

        # check required config fields
        required_config_fields = [
            "pipeline",
            "expr_name",
            "cl_paradigm",
            "train_tasks",
            "eval_after_tasks",
            "unlearning_requests",
            "unlearnable_age",
            "global_seed",
            "cl_dataset",
            "cl_algorithm",
            "cul_algorithm",
            "backbone",
            "optimizer",
            "lr_scheduler",
            "trainer",
            "metrics",
            "lightning_loggers",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]

        # fields whose falsy values are meaningful, so only their presence is
        # checked: `unlearnable_age: null` means tasks are unlearnable at any time
        presence_only_fields = {"unlearnable_age"}

        for field in required_config_fields:
            missing = (
                field not in self.cfg
                if field in presence_only_fields
                else not self.cfg.get(field)
            )
            if missing:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # check unlearning requests
        for task_id, unlearning_task_ids in self.cfg.unlearning_requests.items():
            if task_id not in self.train_tasks:
                raise ValueError(
                    f"Task ID {task_id} in unlearning_requests is not within the train_tasks in the experiment!"
                )
            for unlearning_task_id in unlearning_task_ids:
                if unlearning_task_id not in self.train_tasks:
                    raise ValueError(
                        f"Unlearning task ID {unlearning_task_id} in unlearning_requests is not within the train_tasks in the experiment!"
                    )

    def instantiate_cul_algorithm(self, cul_algorithm_cfg: DictConfig) -> None:
        r"""Instantiate the CUL algorithm object from `cul_algorithm_cfg`.

        The current `self.model` is passed to the instantiated object as its `model` argument, so the model must be instantiated first.

        **Args:**
        - **cul_algorithm_cfg** (`DictConfig`): the Hydra config of the CUL algorithm (with a `_target_` key).
        """
        pylogger.debug(
            "Instantiating CUL algorithm <%s> (clarena.cul_algorithms.CULAlgorithm)...",
            cul_algorithm_cfg.get("_target_"),
        )
        self.cul_algorithm: CULAlgorithm = hydra.utils.instantiate(
            cul_algorithm_cfg,
            model=self.model,
        )
        pylogger.debug(
            "<%s> (clarena.cul_algorithms.CULAlgorithm) instantiated!",
            cul_algorithm_cfg.get("_target_"),
        )

    def unlearnable_task_ids(self, task_id: int) -> list[int]:
        r"""Get the list of unlearnable task IDs at task `task_id`.

        A task `tid` is unlearnable while its age `task_id - tid` is below its unlearnable age (or the age is `None`), and it has not already been unlearned.

        **Args:**
        - **task_id** (`int`): the target task ID to check unlearnable task IDs.

        **Returns:**
        - **unlearnable_task_ids** (`list[int]`): the list of unlearnable task IDs at task `task_id`.
        """
        unlearnable_task_ids = []
        for tid in range(1, task_id + 1):
            unlearnable_age = self.unlearnable_ages[tid]
            if (
                unlearnable_age is None or (task_id - tid) < unlearnable_age
            ) and tid not in self.unlearned_task_ids:
                unlearnable_task_ids.append(tid)

        return unlearnable_task_ids

    def task_ids_just_no_longer_unlearnable(self, task_id: int) -> list[int]:
        r"""Get the list of task IDs just turning not unlearnable at task `task_id`.

        A task expires exactly when its age `task_id - tid` equals its unlearnable age. Tasks with a `None` unlearnable age never expire, and already-unlearned tasks are skipped.

        **Args:**
        - **task_id** (`int`): the target task ID to check.

        **Returns:**
        - **task_ids_just_no_longer_unlearnable** (`list[int]`): the list of task IDs just turning not unlearnable at task `task_id`.
        """
        task_ids_just_no_longer_unlearnable = []
        for tid in range(1, task_id + 1):
            unlearnable_age = self.unlearnable_ages[tid]
            # `None` means the task is unlearnable at any time, so it never
            # expires (and `task_id - None` would raise a TypeError)
            if unlearnable_age is None or tid in self.unlearned_task_ids:
                continue
            if task_id - tid == unlearnable_age:
                task_ids_just_no_longer_unlearnable.append(tid)

        return task_ids_just_no_longer_unlearnable

    def run(self) -> None:
        r"""The main method to run the continual unlearning main experiment.

        Sets the global seed and instantiates the global components. Then, for each task in `train_tasks`, it instantiates the task-specific components and trains and validates the model. It runs the unlearning process if the task has an unlearning request, and tests the model if the task is in `eval_after_tasks`.
        """

        self.set_global_seed(self.global_seed)

        # global components
        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
        self.instantiate_backbone(
            backbone_cfg=self.cfg.backbone, disable_unlearning=False
        )
        self.instantiate_heads(
            cl_paradigm=self.cl_paradigm, input_dim=self.cfg.backbone.output_dim
        )
        self.instantiate_cl_algorithm(
            cl_algorithm_cfg=self.cfg.cl_algorithm,
            backbone=self.backbone,
            heads=self.heads,
            non_algorithmic_hparams=select_hyperparameters_from_config(
                cfg=self.cfg, type=self.cfg.pipeline
            ),
            disable_unlearning=False,
        )  # cl_algorithm should be instantiated after backbone and heads
        self.instantiate_cul_algorithm(
            self.cfg.cul_algorithm
        )  # cul_algorithm should be instantiated after model
        self.instantiate_lightning_loggers(
            lightning_loggers_cfg=self.cfg.lightning_loggers
        )
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics,
            callbacks_cfg=self.cfg.callbacks,
        )

        # task loop
        for task_id in self.train_tasks:

            self.task_id = task_id

            # task-specific components
            self.instantiate_optimizer(
                optimizer_cfg=self.cfg.optimizer,
                task_id=task_id,
            )
            if self.cfg.get("lr_scheduler"):
                self.instantiate_lr_scheduler(
                    lr_scheduler_cfg=self.cfg.lr_scheduler,
                    task_id=task_id,
                )
            self.instantiate_trainer(
                trainer_cfg=self.cfg.trainer,
                lightning_loggers=self.lightning_loggers,
                callbacks=self.callbacks,
                task_id=task_id,
            )  # trainer should be instantiated after lightning loggers and callbacks

            # setup task ID for dataset and model
            self.cl_dataset.setup_task_id(task_id=task_id)
            self.cul_algorithm.setup_task_id(
                task_id=self.task_id,
                unlearning_requests=self.unlearning_requests,
                unlearnable_task_ids=self.unlearnable_task_ids(self.task_id),
                task_ids_just_no_longer_unlearnable=self.task_ids_just_no_longer_unlearnable(
                    self.task_id
                ),
            )
            self.model.setup_task_id(
                task_id=task_id,
                num_classes=len(self.cl_dataset.get_cl_class_map(self.task_id)),
                optimizer=self.optimizer_t,
                lr_scheduler=self.lr_scheduler_t,
            )

            # train and validate the model
            self.trainer_t.fit(
                model=self.model,
                datamodule=self.cl_dataset,
            )

            # unlearn
            if self.task_id in self.unlearning_requests:
                unlearning_task_ids = self.unlearning_requests[self.task_id]
                pylogger.info(
                    "Starting unlearning process for tasks: %s...", unlearning_task_ids
                )
                self.cul_algorithm.unlearn()
                pylogger.info("Unlearning process finished.")

                # record the unlearned tasks so that later tasks no longer report
                # them as unlearnable (or as just turning not unlearnable)
                self.unlearned_task_ids.update(unlearning_task_ids)

            self.cul_algorithm.setup_test_task_id()

            # evaluation after training and validation
            if task_id in self.eval_after_tasks:
                self.trainer_t.test(
                    model=self.model,
                    datamodule=self.cl_dataset,
                )

            self.processed_task_ids.append(task_id)
class CULMainExperiment(CLMainExperiment):
    r"""The base class for the continual unlearning main experiment."""

    def __init__(self, cfg: DictConfig) -> None:
        r"""
        **Args:**
        - **cfg** (`DictConfig`): the complete config dict for the CUL experiment.
        """
        super().__init__(
            cfg
        )  # CUL main experiment inherits all configs from CL main experiment

        CULMainExperiment.sanity_check(self)

        self.cul_algorithm: CULAlgorithm
        r"""Continual unlearning algorithm object."""

        self.unlearning_requests: dict[int, list[int]] = cfg.unlearning_requests
        r"""The unlearning requests for each task in the experiment. Keys are IDs of the tasks that request unlearning after their learning, and values are the list of the previous tasks to be unlearned. Parsed from config and used in the tasks loop."""
        self.unlearned_task_ids: set[int] = set()
        r"""The set of task IDs that have been unlearned in the experiment. Updated in the tasks loop after each unlearning process."""

        # The raw `cfg.train_tasks` is not necessarily an int (it may be a list of
        # task IDs), so size the per-task dict from the parsed `self.train_tasks`
        # instead of computing `cfg.train_tasks + 1`. Every ID from 1 up to the
        # largest one gets an entry because the lookups in the methods below scan
        # `range(1, task_id + 1)`.
        self.unlearnable_ages: dict[int, int | None] | DictConfig = (
            cfg.unlearnable_age
            if isinstance(cfg.unlearnable_age, DictConfig)
            else {
                task_id: cfg.unlearnable_age
                for task_id in range(1, max(self.train_tasks) + 1)
            }
        )
        r"""The dict of task unlearnable ages. Keys are task IDs and values are the unlearnable age of the corresponding task. A task can no longer be unlearned once its age (i.e., the number of tasks learned after it) reaches this value. If `None`, the task is unlearnable at any time."""

    def sanity_check(self) -> None:
        r"""Check the sanity of the config dict `self.cfg`.

        **Raises:**
        - **KeyError**: when a required config field is missing, or (for every field except `unlearnable_age`) has a falsy value.
        - **ValueError**: when a task ID in `unlearning_requests` (either a requesting task or a task to be unlearned) is not within `train_tasks`.
        """

        # check required config fields
        required_config_fields = [
            "pipeline",
            "expr_name",
            "cl_paradigm",
            "train_tasks",
            "eval_after_tasks",
            "unlearning_requests",
            "unlearnable_age",
            "global_seed",
            "cl_dataset",
            "cl_algorithm",
            "cul_algorithm",
            "backbone",
            "optimizer",
            "lr_scheduler",
            "trainer",
            "metrics",
            "lightning_loggers",
            "callbacks",
            "output_dir",
            # "hydra" is excluded as it doesn't appear
            "misc",
        ]

        # fields whose falsy values are meaningful, so only their presence is
        # checked: `unlearnable_age: null` means tasks are unlearnable at any time
        presence_only_fields = {"unlearnable_age"}

        for field in required_config_fields:
            missing = (
                field not in self.cfg
                if field in presence_only_fields
                else not self.cfg.get(field)
            )
            if missing:
                raise KeyError(
                    f"Field `{field}` is required in the experiment index config."
                )

        # check unlearning requests
        for task_id, unlearning_task_ids in self.cfg.unlearning_requests.items():
            if task_id not in self.train_tasks:
                raise ValueError(
                    f"Task ID {task_id} in unlearning_requests is not within the train_tasks in the experiment!"
                )
            for unlearning_task_id in unlearning_task_ids:
                if unlearning_task_id not in self.train_tasks:
                    raise ValueError(
                        f"Unlearning task ID {unlearning_task_id} in unlearning_requests is not within the train_tasks in the experiment!"
                    )

    def instantiate_cul_algorithm(self, cul_algorithm_cfg: DictConfig) -> None:
        r"""Instantiate the CUL algorithm object from `cul_algorithm_cfg`.

        The current `self.model` is passed to the instantiated object as its `model` argument, so the model must be instantiated first.

        **Args:**
        - **cul_algorithm_cfg** (`DictConfig`): the Hydra config of the CUL algorithm (with a `_target_` key).
        """
        pylogger.debug(
            "Instantiating CUL algorithm <%s> (clarena.cul_algorithms.CULAlgorithm)...",
            cul_algorithm_cfg.get("_target_"),
        )
        self.cul_algorithm: CULAlgorithm = hydra.utils.instantiate(
            cul_algorithm_cfg,
            model=self.model,
        )
        pylogger.debug(
            "<%s> (clarena.cul_algorithms.CULAlgorithm) instantiated!",
            cul_algorithm_cfg.get("_target_"),
        )

    def unlearnable_task_ids(self, task_id: int) -> list[int]:
        r"""Get the list of unlearnable task IDs at task `task_id`.

        A task `tid` is unlearnable while its age `task_id - tid` is below its unlearnable age (or the age is `None`), and it has not already been unlearned.

        **Args:**
        - **task_id** (`int`): the target task ID to check unlearnable task IDs.

        **Returns:**
        - **unlearnable_task_ids** (`list[int]`): the list of unlearnable task IDs at task `task_id`.
        """
        unlearnable_task_ids = []
        for tid in range(1, task_id + 1):
            unlearnable_age = self.unlearnable_ages[tid]
            if (
                unlearnable_age is None or (task_id - tid) < unlearnable_age
            ) and tid not in self.unlearned_task_ids:
                unlearnable_task_ids.append(tid)

        return unlearnable_task_ids

    def task_ids_just_no_longer_unlearnable(self, task_id: int) -> list[int]:
        r"""Get the list of task IDs just turning not unlearnable at task `task_id`.

        A task expires exactly when its age `task_id - tid` equals its unlearnable age. Tasks with a `None` unlearnable age never expire, and already-unlearned tasks are skipped.

        **Args:**
        - **task_id** (`int`): the target task ID to check.

        **Returns:**
        - **task_ids_just_no_longer_unlearnable** (`list[int]`): the list of task IDs just turning not unlearnable at task `task_id`.
        """
        task_ids_just_no_longer_unlearnable = []
        for tid in range(1, task_id + 1):
            unlearnable_age = self.unlearnable_ages[tid]
            # `None` means the task is unlearnable at any time, so it never
            # expires (and `task_id - None` would raise a TypeError)
            if unlearnable_age is None or tid in self.unlearned_task_ids:
                continue
            if task_id - tid == unlearnable_age:
                task_ids_just_no_longer_unlearnable.append(tid)

        return task_ids_just_no_longer_unlearnable

    def run(self) -> None:
        r"""The main method to run the continual unlearning main experiment.

        Sets the global seed and instantiates the global components. Then, for each task in `train_tasks`, it instantiates the task-specific components and trains and validates the model. It runs the unlearning process if the task has an unlearning request, and tests the model if the task is in `eval_after_tasks`.
        """

        self.set_global_seed(self.global_seed)

        # global components
        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
        self.instantiate_backbone(
            backbone_cfg=self.cfg.backbone, disable_unlearning=False
        )
        self.instantiate_heads(
            cl_paradigm=self.cl_paradigm, input_dim=self.cfg.backbone.output_dim
        )
        self.instantiate_cl_algorithm(
            cl_algorithm_cfg=self.cfg.cl_algorithm,
            backbone=self.backbone,
            heads=self.heads,
            non_algorithmic_hparams=select_hyperparameters_from_config(
                cfg=self.cfg, type=self.cfg.pipeline
            ),
            disable_unlearning=False,
        )  # cl_algorithm should be instantiated after backbone and heads
        self.instantiate_cul_algorithm(
            self.cfg.cul_algorithm
        )  # cul_algorithm should be instantiated after model
        self.instantiate_lightning_loggers(
            lightning_loggers_cfg=self.cfg.lightning_loggers
        )
        self.instantiate_callbacks(
            metrics_cfg=self.cfg.metrics,
            callbacks_cfg=self.cfg.callbacks,
        )

        # task loop
        for task_id in self.train_tasks:

            self.task_id = task_id

            # task-specific components
            self.instantiate_optimizer(
                optimizer_cfg=self.cfg.optimizer,
                task_id=task_id,
            )
            if self.cfg.get("lr_scheduler"):
                self.instantiate_lr_scheduler(
                    lr_scheduler_cfg=self.cfg.lr_scheduler,
                    task_id=task_id,
                )
            self.instantiate_trainer(
                trainer_cfg=self.cfg.trainer,
                lightning_loggers=self.lightning_loggers,
                callbacks=self.callbacks,
                task_id=task_id,
            )  # trainer should be instantiated after lightning loggers and callbacks

            # setup task ID for dataset and model
            self.cl_dataset.setup_task_id(task_id=task_id)
            self.cul_algorithm.setup_task_id(
                task_id=self.task_id,
                unlearning_requests=self.unlearning_requests,
                unlearnable_task_ids=self.unlearnable_task_ids(self.task_id),
                task_ids_just_no_longer_unlearnable=self.task_ids_just_no_longer_unlearnable(
                    self.task_id
                ),
            )
            self.model.setup_task_id(
                task_id=task_id,
                num_classes=len(self.cl_dataset.get_cl_class_map(self.task_id)),
                optimizer=self.optimizer_t,
                lr_scheduler=self.lr_scheduler_t,
            )

            # train and validate the model
            self.trainer_t.fit(
                model=self.model,
                datamodule=self.cl_dataset,
            )

            # unlearn
            if self.task_id in self.unlearning_requests:
                unlearning_task_ids = self.unlearning_requests[self.task_id]
                pylogger.info(
                    "Starting unlearning process for tasks: %s...", unlearning_task_ids
                )
                self.cul_algorithm.unlearn()
                pylogger.info("Unlearning process finished.")

                # record the unlearned tasks so that later tasks no longer report
                # them as unlearnable (or as just turning not unlearnable)
                self.unlearned_task_ids.update(unlearning_task_ids)

            self.cul_algorithm.setup_test_task_id()

            # evaluation after training and validation
            if task_id in self.eval_after_tasks:
                self.trainer_t.test(
                    model=self.model,
                    datamodule=self.cl_dataset,
                )

            self.processed_task_ids.append(task_id)
The base class for the continual unlearning main experiment.
def __init__(self, cfg: DictConfig) -> None:
    r"""
    **Args:**
    - **cfg** (`DictConfig`): the complete config dict for the CUL experiment.
    """
    super().__init__(
        cfg
    )  # CUL main experiment inherits all configs from CL main experiment

    CULMainExperiment.sanity_check(self)

    self.cul_algorithm: CULAlgorithm
    r"""Continual unlearning algorithm object."""

    self.unlearning_requests: dict[int, list[int]] = cfg.unlearning_requests
    r"""The unlearning requests for each task in the experiment. Keys are IDs of the tasks that request unlearning after their learning, and values are the list of the previous tasks to be unlearned. Parsed from config and used in the tasks loop."""
    self.unlearned_task_ids: set[int] = set()
    r"""The set of task IDs that have been unlearned in the experiment. Updated in the tasks loop when unlearning requests are made."""

    # The raw `cfg.train_tasks` is not necessarily an int (it may be a list of
    # task IDs), so size the per-task dict from the parsed `self.train_tasks`
    # instead of computing `cfg.train_tasks + 1`. Every ID from 1 up to the
    # largest one gets an entry because the unlearnability lookups scan
    # `range(1, task_id + 1)`.
    self.unlearnable_ages: dict[int, int | None] | DictConfig = (
        cfg.unlearnable_age
        if isinstance(cfg.unlearnable_age, DictConfig)
        else {
            task_id: cfg.unlearnable_age
            for task_id in range(1, max(self.train_tasks) + 1)
        }
    )
    r"""The dict of task unlearnable ages. Keys are task IDs and values are the unlearnable age of the corresponding task. A task can no longer be unlearned once its age (i.e., the number of tasks learned after it) reaches this value. If `None`, the task is unlearnable at any time."""
Args:
- cfg (DictConfig): the complete config dict for the CUL experiment.
The unlearning requests for each task in the experiment. Keys are IDs of the tasks that request unlearning after their learning, and values are the list of the previous tasks to be unlearned. Parsed from config and used in the tasks loop.
The set of task IDs that have been unlearned in the experiment. Updated in the tasks loop when unlearning requests are made.
The dict of task unlearnable ages. Keys are task IDs and values are the unlearnable age of the corresponding task. A task can no longer be unlearned once its age (i.e., the number of tasks learned after it) reaches this value. If None, the task is unlearnable at any time.
def sanity_check(self) -> None:
    r"""Check the sanity of the config dict `self.cfg`.

    **Raises:**
    - **KeyError**: when a required config field is missing, or (for every field except `unlearnable_age`) has a falsy value.
    - **ValueError**: when a task ID in `unlearning_requests` (either a requesting task or a task to be unlearned) is not within `train_tasks`.
    """

    # check required config fields
    required_config_fields = [
        "pipeline",
        "expr_name",
        "cl_paradigm",
        "train_tasks",
        "eval_after_tasks",
        "unlearning_requests",
        "unlearnable_age",
        "global_seed",
        "cl_dataset",
        "cl_algorithm",
        "cul_algorithm",
        "backbone",
        "optimizer",
        "lr_scheduler",
        "trainer",
        "metrics",
        "lightning_loggers",
        "callbacks",
        "output_dir",
        # "hydra" is excluded as it doesn't appear
        "misc",
    ]

    # fields whose falsy values are meaningful, so only their presence is
    # checked: `unlearnable_age: null` means tasks are unlearnable at any time
    presence_only_fields = {"unlearnable_age"}

    for field in required_config_fields:
        missing = (
            field not in self.cfg
            if field in presence_only_fields
            else not self.cfg.get(field)
        )
        if missing:
            raise KeyError(
                f"Field `{field}` is required in the experiment index config."
            )

    # check unlearning requests
    for task_id, unlearning_task_ids in self.cfg.unlearning_requests.items():
        if task_id not in self.train_tasks:
            raise ValueError(
                f"Task ID {task_id} in unlearning_requests is not within the train_tasks in the experiment!"
            )
        for unlearning_task_id in unlearning_task_ids:
            if unlearning_task_id not in self.train_tasks:
                raise ValueError(
                    f"Unlearning task ID {unlearning_task_id} in unlearning_requests is not within the train_tasks in the experiment!"
                )
Check the sanity of the config dict self.cfg.
def instantiate_cul_algorithm(self, cul_algorithm_cfg: DictConfig) -> None:
    r"""Instantiate the CUL algorithm object from `cul_algorithm_cfg`.

    The result is stored in `self.cul_algorithm`. The current `self.model` is forwarded to `hydra.utils.instantiate` as the `model` keyword, so the model must exist before this is called.

    **Args:**
    - **cul_algorithm_cfg** (`DictConfig`): the Hydra config of the CUL algorithm (its `_target_` is logged).
    """
    target = cul_algorithm_cfg.get("_target_")
    pylogger.debug(
        "Instantiating CUL algorithm <%s> (clarena.cul_algorithms.CULAlgorithm)...",
        target,
    )

    algorithm = hydra.utils.instantiate(cul_algorithm_cfg, model=self.model)
    self.cul_algorithm = algorithm

    pylogger.debug(
        "<%s> (clarena.cul_algorithms.CULAlgorithm) instantiated!",
        target,
    )
Instantiate the CUL algorithm object from cul_algorithm_cfg.
def unlearnable_task_ids(self, task_id: int) -> list[int]:
    r"""Get the list of unlearnable task IDs at task `task_id`.

    Task `tid` (for every `tid` from 1 to `task_id`) qualifies when its age `task_id - tid` is strictly below its unlearnable age (or that age is `None`), and it is not in `self.unlearned_task_ids`.

    **Args:**
    - **task_id** (`int`): the target task ID to check unlearnable task IDs.

    **Returns:**
    - **unlearnable_task_ids** (`list[int]`): the list of unlearnable task IDs at task `task_id`, in ascending order.
    """
    candidates: list[int] = []
    for tid in range(1, task_id + 1):
        age_limit = self.unlearnable_ages[tid]
        within_window = age_limit is None or task_id - tid < age_limit
        if within_window and tid not in self.unlearned_task_ids:
            candidates.append(tid)
    return candidates
def task_ids_just_no_longer_unlearnable(self, task_id: int) -> list[int]:
    r"""Get the list of task IDs just turning not unlearnable at task `task_id`.

    A task expires exactly when its age `task_id - tid` equals its unlearnable age. This is the first task at which `(task_id - tid) < unlearnable_age` no longer holds. Tasks with a `None` unlearnable age never expire, and already-unlearned tasks are skipped.

    **Args:**
    - **task_id** (`int`): the target task ID to check.

    **Returns:**
    - **task_ids_just_no_longer_unlearnable** (`list[int]`): the list of task IDs just turning not unlearnable at task `task_id`.
    """
    task_ids_just_no_longer_unlearnable = []
    for tid in range(1, task_id + 1):
        unlearnable_age = self.unlearnable_ages[tid]
        # `None` means the task is unlearnable at any time, so it never
        # expires (and `task_id - None` would raise a TypeError)
        if unlearnable_age is None or tid in self.unlearned_task_ids:
            continue
        if task_id - tid == unlearnable_age:
            task_ids_just_no_longer_unlearnable.append(tid)

    return task_ids_just_no_longer_unlearnable
def run(self) -> None:
    r"""The main method to run the continual unlearning main experiment.

    Sets the global seed and instantiates the global components. Then, for each task in `train_tasks`, it instantiates the task-specific components and trains and validates the model. It runs the unlearning process if the task has an unlearning request, and tests the model if the task is in `eval_after_tasks`.
    """

    self.set_global_seed(self.global_seed)

    # global components
    self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
    self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
    self.instantiate_backbone(
        backbone_cfg=self.cfg.backbone, disable_unlearning=False
    )
    self.instantiate_heads(
        cl_paradigm=self.cl_paradigm, input_dim=self.cfg.backbone.output_dim
    )
    self.instantiate_cl_algorithm(
        cl_algorithm_cfg=self.cfg.cl_algorithm,
        backbone=self.backbone,
        heads=self.heads,
        non_algorithmic_hparams=select_hyperparameters_from_config(
            cfg=self.cfg, type=self.cfg.pipeline
        ),
        disable_unlearning=False,
    )  # cl_algorithm should be instantiated after backbone and heads
    self.instantiate_cul_algorithm(
        self.cfg.cul_algorithm
    )  # cul_algorithm should be instantiated after model
    self.instantiate_lightning_loggers(
        lightning_loggers_cfg=self.cfg.lightning_loggers
    )
    self.instantiate_callbacks(
        metrics_cfg=self.cfg.metrics,
        callbacks_cfg=self.cfg.callbacks,
    )

    # task loop
    for task_id in self.train_tasks:

        self.task_id = task_id

        # task-specific components
        self.instantiate_optimizer(
            optimizer_cfg=self.cfg.optimizer,
            task_id=task_id,
        )
        if self.cfg.get("lr_scheduler"):
            self.instantiate_lr_scheduler(
                lr_scheduler_cfg=self.cfg.lr_scheduler,
                task_id=task_id,
            )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
            task_id=task_id,
        )  # trainer should be instantiated after lightning loggers and callbacks

        # setup task ID for dataset and model
        self.cl_dataset.setup_task_id(task_id=task_id)
        self.cul_algorithm.setup_task_id(
            task_id=self.task_id,
            unlearning_requests=self.unlearning_requests,
            unlearnable_task_ids=self.unlearnable_task_ids(self.task_id),
            task_ids_just_no_longer_unlearnable=self.task_ids_just_no_longer_unlearnable(
                self.task_id
            ),
        )
        self.model.setup_task_id(
            task_id=task_id,
            num_classes=len(self.cl_dataset.get_cl_class_map(self.task_id)),
            optimizer=self.optimizer_t,
            lr_scheduler=self.lr_scheduler_t,
        )

        # train and validate the model
        self.trainer_t.fit(
            model=self.model,
            datamodule=self.cl_dataset,
        )

        # unlearn
        if self.task_id in self.unlearning_requests:
            unlearning_task_ids = self.unlearning_requests[self.task_id]
            pylogger.info(
                "Starting unlearning process for tasks: %s...", unlearning_task_ids
            )
            self.cul_algorithm.unlearn()
            pylogger.info("Unlearning process finished.")

            # record the unlearned tasks so that later tasks no longer report
            # them as unlearnable (or as just turning not unlearnable)
            self.unlearned_task_ids.update(unlearning_task_ids)

        self.cul_algorithm.setup_test_task_id()

        # evaluation after training and validation
        if task_id in self.eval_after_tasks:
            self.trainer_t.test(
                model=self.model,
                datamodule=self.cl_dataset,
            )

        self.processed_task_ids.append(task_id)
The main method to run the continual unlearning main experiment.
Inherited Members
- clarena.pipelines.cl_main_expr.CLMainExperiment
- cfg
- cl_paradigm
- train_tasks
- eval_after_tasks
- global_seed
- output_dir
- cl_dataset
- backbone
- heads
- model
- lightning_loggers
- callbacks
- optimizer_t
- lr_scheduler_t
- trainer_t
- task_id
- processed_task_ids
- instantiate_cl_dataset
- instantiate_backbone
- instantiate_heads
- instantiate_cl_algorithm
- instantiate_optimizer
- instantiate_lr_scheduler
- instantiate_lightning_loggers
- instantiate_callbacks
- instantiate_trainer
- set_global_seed