clarena.pipelines.cul_main_expr

The submodule in pipelines for continual unlearning main experiment.

  1r"""
  2The submodule in `pipelines` for continual unlearning main experiment.
  3"""
  4
  5__all__ = ["CULMainExperiment"]
  6
  7import logging
  8
  9import hydra
 10from omegaconf import DictConfig, ListConfig
 11
 12from clarena.cul_algorithms import CULAlgorithm
 13from clarena.pipelines import CLMainExperiment
 14from clarena.utils.cfg import select_hyperparameters_from_config
 15
 16# always get logger for built-in logging in each module
 17pylogger = logging.getLogger(__name__)
 18
 19
 20class CULMainExperiment(CLMainExperiment):
 21    r"""The base class for continual unlearning main experiment."""
 22
 23    def __init__(self, cfg: DictConfig) -> None:
 24        r"""
 25        **Args:**
 26        - **cfg** (`DictConfig`): the complete config dict for the CUL experiment.
 27        """
 28        super().__init__(
 29            cfg
 30        )  # CUL main experiment inherits all configs from CL main experiment
 31
 32        CULMainExperiment.sanity_check(self)
 33
 34        self.cul_algorithm: CULAlgorithm
 35        r"""Continual unlearning algorithm object."""
 36
 37        self.unlearning_requests: dict[int, list[int]] = cfg.unlearning_requests
 38        r"""The unlearning requests for each task in the experiment. Keys are IDs of the tasks that request unlearning after their learning, and values are the list of the previous tasks to be unlearned. Parsed from config and used in the tasks loop."""
 39        self.unlearned_task_ids: set[int] = set()
 40        r"""The list of task IDs that have been unlearned in the experiment. Updated in the tasks loop when unlearning requests are made."""
 41
 42        self.unlearnable_ages: dict[int, int | None] | int | None = (
 43            cfg.unlearnable_age
 44            if isinstance(cfg.unlearnable_age, DictConfig)
 45            else {
 46                task_id: cfg.unlearnable_age
 47                for task_id in range(1, cfg.train_tasks + 1)
 48            }
 49        )
 50        r"""The dict of task unlearnable ages. Keys are task IDs and values are the unlearnable age of the corresponding task. A task cannot be unlearned when its age (i.e., the number of tasks learned after it) exceeds this value. If `None`, the task is unlearnable at any time."""
 51
 52    def sanity_check(self) -> None:
 53        r"""Check the sanity of the config dict `self.cfg`."""
 54
 55        # check required config fields
 56        required_config_fields = [
 57            "pipeline",
 58            "expr_name",
 59            "cl_paradigm",
 60            "train_tasks",
 61            "eval_after_tasks",
 62            "unlearning_requests",
 63            "unlearnable_age",
 64            "global_seed",
 65            "cl_dataset",
 66            "cl_algorithm",
 67            "cul_algorithm",
 68            "backbone",
 69            "optimizer",
 70            "lr_scheduler",
 71            "trainer",
 72            "metrics",
 73            "lightning_loggers",
 74            "callbacks",
 75            "output_dir",
 76            # "hydra" is excluded as it doesn't appear
 77            "misc",
 78        ]
 79
 80        for field in required_config_fields:
 81            if not self.cfg.get(field):
 82                raise KeyError(
 83                    f"Field `{field}` is required in the experiment index config."
 84                )
 85
 86        # check unlearning requests
 87        for task_id, unlearning_task_ids in self.cfg.unlearning_requests.items():
 88            if task_id not in self.train_tasks:
 89                raise ValueError(
 90                    f"Task ID {task_id} in unlearning_requests is not within the train_tasks in the experiment!"
 91                )
 92            for unlearning_task_id in unlearning_task_ids:
 93                if unlearning_task_id not in self.train_tasks:
 94                    raise ValueError(
 95                        f"Unlearning task ID {unlearning_task_id} in unlearning_requests is not within the train_tasks in the experiment!"
 96                    )
 97
 98    def instantiate_cul_algorithm(self, cul_algorithm_cfg: DictConfig) -> None:
 99        r"""Instantiate the CUL algorithm object from `cul_algorithm_cfg`."""
100        pylogger.debug(
101            "Instantiating CUL algorithm <%s> (clarena.cul_algorithms.CULAlgorithm)...",
102            cul_algorithm_cfg.get("_target_"),
103        )
104        self.cul_algorithm: CULAlgorithm = hydra.utils.instantiate(
105            cul_algorithm_cfg,
106            model=self.model,
107        )
108        pylogger.debug(
109            "<%s> (clarena.cul_algorithms.CULAlgorithm) instantiated!",
110            cul_algorithm_cfg.get("_target_"),
111        )
112
113    def unlearnable_task_ids(self, task_id: int) -> list[int]:
114        r"""Get the list of unlearnable task IDs at task `task_id`.
115
116        **Args:**
117        - **task_id** (`int`): the target task ID to check unlearnable task IDs.
118
119        **Returns:**
120        - **unlearnable_task_ids** (`list[int]`): the list of unlearnable task IDs at task `task_id`.
121        """
122        unlearnable_task_ids = []
123        for tid in range(1, task_id + 1):
124            unlearnable_age = self.unlearnable_ages[tid]
125            if (
126                unlearnable_age is None or (task_id - tid) < unlearnable_age
127            ) and tid not in self.unlearned_task_ids:
128                unlearnable_task_ids.append(tid)
129
130        return unlearnable_task_ids
131
132    def task_ids_just_no_longer_unlearnable(self, task_id: int) -> list[int]:
133        r"""Get the list of task IDs just turning not unlearnable at task `task_id`.
134
135        **Args:**
136        - **task_id** (`int`): the target task ID to check.
137
138        **Returns:**
139        - **task_ids_just_no_longer_unlearnable** (`list[int]`): the list of task IDs just turning not unlearnable at task `task_id`.
140        """
141        task_ids_just_no_longer_unlearnable = []
142        for tid in range(1, task_id + 1):
143            unlearnable_age = self.unlearnable_ages[tid]
144            if task_id - unlearnable_age == tid and tid not in self.unlearned_task_ids:
145                task_ids_just_no_longer_unlearnable.append(tid)
146
147        return task_ids_just_no_longer_unlearnable
148
149    def run(self) -> None:
150        r"""The main method to run the continual unlearning main experiment."""
151
152        self.set_global_seed(self.global_seed)
153
154        # global components
155        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
156        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
157        self.instantiate_backbone(
158            backbone_cfg=self.cfg.backbone, disable_unlearning=False
159        )
160        self.instantiate_heads(
161            cl_paradigm=self.cl_paradigm, input_dim=self.cfg.backbone.output_dim
162        )
163        self.instantiate_cl_algorithm(
164            cl_algorithm_cfg=self.cfg.cl_algorithm,
165            backbone=self.backbone,
166            heads=self.heads,
167            non_algorithmic_hparams=select_hyperparameters_from_config(
168                cfg=self.cfg, type=self.cfg.pipeline
169            ),
170            disable_unlearning=False,
171        )  # cl_algorithm should be instantiated after backbone and heads
172        self.instantiate_cul_algorithm(
173            self.cfg.cul_algorithm
174        )  # cul_algorithm should be instantiated after model
175        self.instantiate_lightning_loggers(
176            lightning_loggers_cfg=self.cfg.lightning_loggers
177        )
178        self.instantiate_callbacks(
179            metrics_cfg=self.cfg.metrics,
180            callbacks_cfg=self.cfg.callbacks,
181        )
182
183        # task loop
184        for task_id in self.train_tasks:
185
186            self.task_id = task_id
187
188            # task-specific components
189            self.instantiate_optimizer(
190                optimizer_cfg=self.cfg.optimizer,
191                task_id=task_id,
192            )
193            if self.cfg.get("lr_scheduler"):
194                self.instantiate_lr_scheduler(
195                    lr_scheduler_cfg=self.cfg.lr_scheduler,
196                    task_id=task_id,
197                )
198            self.instantiate_trainer(
199                trainer_cfg=self.cfg.trainer,
200                lightning_loggers=self.lightning_loggers,
201                callbacks=self.callbacks,
202                task_id=task_id,
203            )  # trainer should be instantiated after lightning loggers and callbacks
204
205            # setup task ID for dataset and model
206            self.cl_dataset.setup_task_id(task_id=task_id)
207            self.cul_algorithm.setup_task_id(
208                task_id=self.task_id,
209                unlearning_requests=self.unlearning_requests,
210                unlearnable_task_ids=self.unlearnable_task_ids(self.task_id),
211                task_ids_just_no_longer_unlearnable=self.task_ids_just_no_longer_unlearnable(
212                    self.task_id
213                ),
214            )
215            self.model.setup_task_id(
216                task_id=task_id,
217                num_classes=len(self.cl_dataset.get_cl_class_map(self.task_id)),
218                optimizer=self.optimizer_t,
219                lr_scheduler=self.lr_scheduler_t,
220            )
221
222            # train and validate the model
223            self.trainer_t.fit(
224                model=self.model,
225                datamodule=self.cl_dataset,
226            )
227
228            # unlearn
229            if self.task_id in self.unlearning_requests.keys():
230                unlearning_task_ids = self.unlearning_requests[self.task_id]
231                pylogger.info(
232                    "Starting unlearning process for tasks: %s...", unlearning_task_ids
233                )
234                self.cul_algorithm.unlearn()
235                pylogger.info("Unlearning process finished.")
236
237            # for unlearning_task_id in self.cul_algorithm.unlearning_task_ids:
238            #     self.processed_task_ids.remove(unlearning_task_id)
239
240            self.cul_algorithm.setup_test_task_id()
241
242            # evaluation after training and validation
243            if task_id in self.eval_after_tasks:
244                self.trainer_t.test(
245                    model=self.model,
246                    datamodule=self.cl_dataset,
247                )
248
249            self.processed_task_ids.append(task_id)
class CULMainExperiment(clarena.pipelines.cl_main_expr.CLMainExperiment):
 21class CULMainExperiment(CLMainExperiment):
 22    r"""The base class for continual unlearning main experiment."""
 23
 24    def __init__(self, cfg: DictConfig) -> None:
 25        r"""
 26        **Args:**
 27        - **cfg** (`DictConfig`): the complete config dict for the CUL experiment.
 28        """
 29        super().__init__(
 30            cfg
 31        )  # CUL main experiment inherits all configs from CL main experiment
 32
 33        CULMainExperiment.sanity_check(self)
 34
 35        self.cul_algorithm: CULAlgorithm
 36        r"""Continual unlearning algorithm object."""
 37
 38        self.unlearning_requests: dict[int, list[int]] = cfg.unlearning_requests
 39        r"""The unlearning requests for each task in the experiment. Keys are IDs of the tasks that request unlearning after their learning, and values are the list of the previous tasks to be unlearned. Parsed from config and used in the tasks loop."""
 40        self.unlearned_task_ids: set[int] = set()
 41        r"""The list of task IDs that have been unlearned in the experiment. Updated in the tasks loop when unlearning requests are made."""
 42
 43        self.unlearnable_ages: dict[int, int | None] | int | None = (
 44            cfg.unlearnable_age
 45            if isinstance(cfg.unlearnable_age, DictConfig)
 46            else {
 47                task_id: cfg.unlearnable_age
 48                for task_id in range(1, cfg.train_tasks + 1)
 49            }
 50        )
 51        r"""The dict of task unlearnable ages. Keys are task IDs and values are the unlearnable age of the corresponding task. A task cannot be unlearned when its age (i.e., the number of tasks learned after it) exceeds this value. If `None`, the task is unlearnable at any time."""
 52
 53    def sanity_check(self) -> None:
 54        r"""Check the sanity of the config dict `self.cfg`."""
 55
 56        # check required config fields
 57        required_config_fields = [
 58            "pipeline",
 59            "expr_name",
 60            "cl_paradigm",
 61            "train_tasks",
 62            "eval_after_tasks",
 63            "unlearning_requests",
 64            "unlearnable_age",
 65            "global_seed",
 66            "cl_dataset",
 67            "cl_algorithm",
 68            "cul_algorithm",
 69            "backbone",
 70            "optimizer",
 71            "lr_scheduler",
 72            "trainer",
 73            "metrics",
 74            "lightning_loggers",
 75            "callbacks",
 76            "output_dir",
 77            # "hydra" is excluded as it doesn't appear
 78            "misc",
 79        ]
 80
 81        for field in required_config_fields:
 82            if not self.cfg.get(field):
 83                raise KeyError(
 84                    f"Field `{field}` is required in the experiment index config."
 85                )
 86
 87        # check unlearning requests
 88        for task_id, unlearning_task_ids in self.cfg.unlearning_requests.items():
 89            if task_id not in self.train_tasks:
 90                raise ValueError(
 91                    f"Task ID {task_id} in unlearning_requests is not within the train_tasks in the experiment!"
 92                )
 93            for unlearning_task_id in unlearning_task_ids:
 94                if unlearning_task_id not in self.train_tasks:
 95                    raise ValueError(
 96                        f"Unlearning task ID {unlearning_task_id} in unlearning_requests is not within the train_tasks in the experiment!"
 97                    )
 98
 99    def instantiate_cul_algorithm(self, cul_algorithm_cfg: DictConfig) -> None:
100        r"""Instantiate the CUL algorithm object from `cul_algorithm_cfg`."""
101        pylogger.debug(
102            "Instantiating CUL algorithm <%s> (clarena.cul_algorithms.CULAlgorithm)...",
103            cul_algorithm_cfg.get("_target_"),
104        )
105        self.cul_algorithm: CULAlgorithm = hydra.utils.instantiate(
106            cul_algorithm_cfg,
107            model=self.model,
108        )
109        pylogger.debug(
110            "<%s> (clarena.cul_algorithms.CULAlgorithm) instantiated!",
111            cul_algorithm_cfg.get("_target_"),
112        )
113
114    def unlearnable_task_ids(self, task_id: int) -> list[int]:
115        r"""Get the list of unlearnable task IDs at task `task_id`.
116
117        **Args:**
118        - **task_id** (`int`): the target task ID to check unlearnable task IDs.
119
120        **Returns:**
121        - **unlearnable_task_ids** (`list[int]`): the list of unlearnable task IDs at task `task_id`.
122        """
123        unlearnable_task_ids = []
124        for tid in range(1, task_id + 1):
125            unlearnable_age = self.unlearnable_ages[tid]
126            if (
127                unlearnable_age is None or (task_id - tid) < unlearnable_age
128            ) and tid not in self.unlearned_task_ids:
129                unlearnable_task_ids.append(tid)
130
131        return unlearnable_task_ids
132
133    def task_ids_just_no_longer_unlearnable(self, task_id: int) -> list[int]:
134        r"""Get the list of task IDs just turning not unlearnable at task `task_id`.
135
136        **Args:**
137        - **task_id** (`int`): the target task ID to check.
138
139        **Returns:**
140        - **task_ids_just_no_longer_unlearnable** (`list[int]`): the list of task IDs just turning not unlearnable at task `task_id`.
141        """
142        task_ids_just_no_longer_unlearnable = []
143        for tid in range(1, task_id + 1):
144            unlearnable_age = self.unlearnable_ages[tid]
145            if task_id - unlearnable_age == tid and tid not in self.unlearned_task_ids:
146                task_ids_just_no_longer_unlearnable.append(tid)
147
148        return task_ids_just_no_longer_unlearnable
149
150    def run(self) -> None:
151        r"""The main method to run the continual unlearning main experiment."""
152
153        self.set_global_seed(self.global_seed)
154
155        # global components
156        self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
157        self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
158        self.instantiate_backbone(
159            backbone_cfg=self.cfg.backbone, disable_unlearning=False
160        )
161        self.instantiate_heads(
162            cl_paradigm=self.cl_paradigm, input_dim=self.cfg.backbone.output_dim
163        )
164        self.instantiate_cl_algorithm(
165            cl_algorithm_cfg=self.cfg.cl_algorithm,
166            backbone=self.backbone,
167            heads=self.heads,
168            non_algorithmic_hparams=select_hyperparameters_from_config(
169                cfg=self.cfg, type=self.cfg.pipeline
170            ),
171            disable_unlearning=False,
172        )  # cl_algorithm should be instantiated after backbone and heads
173        self.instantiate_cul_algorithm(
174            self.cfg.cul_algorithm
175        )  # cul_algorithm should be instantiated after model
176        self.instantiate_lightning_loggers(
177            lightning_loggers_cfg=self.cfg.lightning_loggers
178        )
179        self.instantiate_callbacks(
180            metrics_cfg=self.cfg.metrics,
181            callbacks_cfg=self.cfg.callbacks,
182        )
183
184        # task loop
185        for task_id in self.train_tasks:
186
187            self.task_id = task_id
188
189            # task-specific components
190            self.instantiate_optimizer(
191                optimizer_cfg=self.cfg.optimizer,
192                task_id=task_id,
193            )
194            if self.cfg.get("lr_scheduler"):
195                self.instantiate_lr_scheduler(
196                    lr_scheduler_cfg=self.cfg.lr_scheduler,
197                    task_id=task_id,
198                )
199            self.instantiate_trainer(
200                trainer_cfg=self.cfg.trainer,
201                lightning_loggers=self.lightning_loggers,
202                callbacks=self.callbacks,
203                task_id=task_id,
204            )  # trainer should be instantiated after lightning loggers and callbacks
205
206            # setup task ID for dataset and model
207            self.cl_dataset.setup_task_id(task_id=task_id)
208            self.cul_algorithm.setup_task_id(
209                task_id=self.task_id,
210                unlearning_requests=self.unlearning_requests,
211                unlearnable_task_ids=self.unlearnable_task_ids(self.task_id),
212                task_ids_just_no_longer_unlearnable=self.task_ids_just_no_longer_unlearnable(
213                    self.task_id
214                ),
215            )
216            self.model.setup_task_id(
217                task_id=task_id,
218                num_classes=len(self.cl_dataset.get_cl_class_map(self.task_id)),
219                optimizer=self.optimizer_t,
220                lr_scheduler=self.lr_scheduler_t,
221            )
222
223            # train and validate the model
224            self.trainer_t.fit(
225                model=self.model,
226                datamodule=self.cl_dataset,
227            )
228
229            # unlearn
230            if self.task_id in self.unlearning_requests.keys():
231                unlearning_task_ids = self.unlearning_requests[self.task_id]
232                pylogger.info(
233                    "Starting unlearning process for tasks: %s...", unlearning_task_ids
234                )
235                self.cul_algorithm.unlearn()
236                pylogger.info("Unlearning process finished.")
237
238            # for unlearning_task_id in self.cul_algorithm.unlearning_task_ids:
239            #     self.processed_task_ids.remove(unlearning_task_id)
240
241            self.cul_algorithm.setup_test_task_id()
242
243            # evaluation after training and validation
244            if task_id in self.eval_after_tasks:
245                self.trainer_t.test(
246                    model=self.model,
247                    datamodule=self.cl_dataset,
248                )
249
250            self.processed_task_ids.append(task_id)

The base class for continual unlearning main experiment.

CULMainExperiment(cfg: omegaconf.dictconfig.DictConfig)
def __init__(self, cfg: DictConfig) -> None:
    r"""
    **Args:**
    - **cfg** (`DictConfig`): the complete config dict for the CUL experiment.
    """
    super().__init__(
        cfg
    )  # CUL main experiment inherits all configs from CL main experiment

    CULMainExperiment.sanity_check(self)

    self.cul_algorithm: CULAlgorithm
    r"""Continual unlearning algorithm object."""

    self.unlearning_requests: dict[int, list[int]] = cfg.unlearning_requests
    r"""The unlearning requests for each task in the experiment. Keys are IDs of the tasks that request unlearning after their learning, and values are the list of the previous tasks to be unlearned. Parsed from config and used in the tasks loop."""
    self.unlearned_task_ids: set[int] = set()
    r"""The set of task IDs that have been unlearned in the experiment. Updated in the tasks loop when unlearning requests are made."""

    # A single (non-dict) `unlearnable_age` applies to every task. The keys come from
    # the parsed `self.train_tasks` (a container of task IDs, as used in
    # `sanity_check()`), because `cfg.train_tasks` may presumably also be given as a
    # list, in which case `cfg.train_tasks + 1` would fail. Every ID from 1 up to the
    # last trained task is covered because `unlearnable_task_ids()` looks up all IDs
    # in `range(1, task_id + 1)`.
    self.unlearnable_ages: dict[int, int | None] = (
        cfg.unlearnable_age
        if isinstance(cfg.unlearnable_age, DictConfig)
        else {
            task_id: cfg.unlearnable_age
            for task_id in range(1, max(self.train_tasks) + 1)
        }
    )
    r"""The dict of task unlearnable ages. Keys are task IDs and values are the unlearnable age of the corresponding task. A task can only be unlearned while its age (i.e., the number of tasks learned after it) is less than this value; once its age reaches this value, it can no longer be unlearned. If `None`, the task is unlearnable at any time."""

Args:

  • cfg (DictConfig): the complete config dict for the CUL experiment.

Continual unlearning algorithm object.

unlearning_requests: dict[int, list[int]]

The unlearning requests for each task in the experiment. Keys are IDs of the tasks that request unlearning after their learning, and values are the list of the previous tasks to be unlearned. Parsed from config and used in the tasks loop.

unlearned_task_ids: set[int]

The set of task IDs that have been unlearned in the experiment. Updated in the tasks loop when unlearning requests are made.

unlearnable_ages: dict[int, int | None] | int | None

The dict of task unlearnable ages. Keys are task IDs and values are the unlearnable age of the corresponding task. A task can only be unlearned while its age (i.e., the number of tasks learned after it) is less than this value; once its age reaches this value, it can no longer be unlearned. If None, the task is unlearnable at any time.

def sanity_check(self) -> None:
def sanity_check(self) -> None:
    r"""Check the sanity of the config dict `self.cfg`.

    **Raises:**
    - **KeyError**: when a required config field is absent.
    - **ValueError**: when `unlearning_requests` refers to task IDs outside `train_tasks`, or requests unlearning of a task that comes after the requesting task.
    """

    # check required config fields
    required_config_fields = [
        "pipeline",
        "expr_name",
        "cl_paradigm",
        "train_tasks",
        "eval_after_tasks",
        "unlearning_requests",
        "unlearnable_age",
        "global_seed",
        "cl_dataset",
        "cl_algorithm",
        "cul_algorithm",
        "backbone",
        "optimizer",
        "lr_scheduler",
        "trainer",
        "metrics",
        "lightning_loggers",
        "callbacks",
        "output_dir",
        # "hydra" is excluded as it doesn't appear
        "misc",
    ]

    for field in required_config_fields:
        # check presence, not truthiness: falsy values such as `unlearnable_age: null`
        # (unlearnable at any time), `unlearnable_age: 0`, `unlearning_requests: {}`
        # or `global_seed: 0` are valid settings
        if field not in self.cfg:
            raise KeyError(
                f"Field `{field}` is required in the experiment index config."
            )

    # check unlearning requests
    for task_id, unlearning_task_ids in self.cfg.unlearning_requests.items():
        if task_id not in self.train_tasks:
            raise ValueError(
                f"Task ID {task_id} in unlearning_requests is not within the train_tasks in the experiment!"
            )
        for unlearning_task_id in unlearning_task_ids:
            if unlearning_task_id not in self.train_tasks:
                raise ValueError(
                    f"Unlearning task ID {unlearning_task_id} in unlearning_requests is not within the train_tasks in the experiment!"
                )
            # only tasks learned up to (and including) the requesting task can be unlearned
            if unlearning_task_id > task_id:
                raise ValueError(
                    f"Unlearning task ID {unlearning_task_id} requested at task {task_id} in unlearning_requests has not been learned yet!"
                )

Check the sanity of the config dict self.cfg.

def instantiate_cul_algorithm(self, cul_algorithm_cfg: omegaconf.dictconfig.DictConfig) -> None:
 99    def instantiate_cul_algorithm(self, cul_algorithm_cfg: DictConfig) -> None:
100        r"""Instantiate the CUL algorithm object from `cul_algorithm_cfg`."""
101        pylogger.debug(
102            "Instantiating CUL algorithm <%s> (clarena.cul_algorithms.CULAlgorithm)...",
103            cul_algorithm_cfg.get("_target_"),
104        )
105        self.cul_algorithm: CULAlgorithm = hydra.utils.instantiate(
106            cul_algorithm_cfg,
107            model=self.model,
108        )
109        pylogger.debug(
110            "<%s> (clarena.cul_algorithms.CULAlgorithm) instantiated!",
111            cul_algorithm_cfg.get("_target_"),
112        )

Instantiate the CUL algorithm object from cul_algorithm_cfg.

def unlearnable_task_ids(self, task_id: int) -> list[int]:
def unlearnable_task_ids(self, task_id: int) -> list[int]:
    r"""Get the list of unlearnable task IDs at task `task_id`.

    Every task ID from 1 to `task_id` is considered in ascending order. A task
    qualifies when its unlearnable age is `None` or its age (`task_id - tid`)
    is still below that unlearnable age, and it is not in `self.unlearned_task_ids`.

    **Args:**
    - **task_id** (`int`): the target task ID to check unlearnable task IDs.

    **Returns:**
    - **unlearnable_task_ids** (`list[int]`): the list of unlearnable task IDs at task `task_id`.
    """
    # pair each candidate with its unlearnable age (looked up for every candidate)
    aged_candidates = (
        (candidate, self.unlearnable_ages[candidate])
        for candidate in range(1, task_id + 1)
    )
    return [
        candidate
        for candidate, max_age in aged_candidates
        if (max_age is None or task_id - candidate < max_age)
        and candidate not in self.unlearned_task_ids
    ]

Get the list of unlearnable task IDs at task task_id.

Args:

  • task_id (int): the target task ID to check unlearnable task IDs.

Returns:

  • unlearnable_task_ids (list[int]): the list of unlearnable task IDs at task task_id.
def task_ids_just_no_longer_unlearnable(self, task_id: int) -> list[int]:
def task_ids_just_no_longer_unlearnable(self, task_id: int) -> list[int]:
    r"""Get the list of task IDs that stop being unlearnable exactly at task `task_id`.

    A task `tid` is unlearnable while `task_id - tid < unlearnable_age` (the rule
    used by `unlearnable_task_ids()`), so it stops being unlearnable exactly when
    `task_id - tid == unlearnable_age`. Tasks whose unlearnable age is `None` are
    unlearnable at any time and are never returned; already-unlearned tasks are skipped.

    **Args:**
    - **task_id** (`int`): the target task ID to check.

    **Returns:**
    - **task_ids_just_no_longer_unlearnable** (`list[int]`): the list of task IDs that stop being unlearnable exactly at task `task_id`.
    """
    task_ids_just_no_longer_unlearnable = []
    for tid in range(1, task_id + 1):
        unlearnable_age = self.unlearnable_ages[tid]
        if unlearnable_age is None:
            # never expires; `task_id - None` would raise `TypeError`
            continue
        if task_id - tid == unlearnable_age and tid not in self.unlearned_task_ids:
            task_ids_just_no_longer_unlearnable.append(tid)

    return task_ids_just_no_longer_unlearnable

Get the list of task IDs that stop being unlearnable exactly at task task_id.

Args:

  • task_id (int): the target task ID to check.

Returns:

  • task_ids_just_no_longer_unlearnable (list[int]): the list of task IDs that stop being unlearnable exactly at task task_id.
def run(self) -> None:
def run(self) -> None:
    r"""The main method to run the continual unlearning main experiment.

    For each task in `train_tasks`, it sets up the task-specific components, trains and
    validates the model, processes the task's unlearning request (if any), records the
    unlearned tasks, and tests the model if the task is in `eval_after_tasks`.
    """

    self.set_global_seed(self.global_seed)

    # global components
    self.instantiate_cl_dataset(cl_dataset_cfg=self.cfg.cl_dataset)
    self.cl_dataset.set_cl_paradigm(cl_paradigm=self.cl_paradigm)
    self.instantiate_backbone(
        backbone_cfg=self.cfg.backbone, disable_unlearning=False
    )
    self.instantiate_heads(
        cl_paradigm=self.cl_paradigm, input_dim=self.cfg.backbone.output_dim
    )
    self.instantiate_cl_algorithm(
        cl_algorithm_cfg=self.cfg.cl_algorithm,
        backbone=self.backbone,
        heads=self.heads,
        non_algorithmic_hparams=select_hyperparameters_from_config(
            cfg=self.cfg, type=self.cfg.pipeline
        ),
        disable_unlearning=False,
    )  # cl_algorithm should be instantiated after backbone and heads
    self.instantiate_cul_algorithm(
        self.cfg.cul_algorithm
    )  # cul_algorithm should be instantiated after model
    self.instantiate_lightning_loggers(
        lightning_loggers_cfg=self.cfg.lightning_loggers
    )
    self.instantiate_callbacks(
        metrics_cfg=self.cfg.metrics,
        callbacks_cfg=self.cfg.callbacks,
    )

    # task loop
    for task_id in self.train_tasks:

        self.task_id = task_id

        # task-specific components
        self.instantiate_optimizer(
            optimizer_cfg=self.cfg.optimizer,
            task_id=task_id,
        )
        if self.cfg.get("lr_scheduler"):
            self.instantiate_lr_scheduler(
                lr_scheduler_cfg=self.cfg.lr_scheduler,
                task_id=task_id,
            )
        self.instantiate_trainer(
            trainer_cfg=self.cfg.trainer,
            lightning_loggers=self.lightning_loggers,
            callbacks=self.callbacks,
            task_id=task_id,
        )  # trainer should be instantiated after lightning loggers and callbacks

        # setup task ID for dataset, CUL algorithm and model
        self.cl_dataset.setup_task_id(task_id=task_id)
        self.cul_algorithm.setup_task_id(
            task_id=task_id,
            unlearning_requests=self.unlearning_requests,
            unlearnable_task_ids=self.unlearnable_task_ids(task_id),
            task_ids_just_no_longer_unlearnable=self.task_ids_just_no_longer_unlearnable(
                task_id
            ),
        )
        self.model.setup_task_id(
            task_id=task_id,
            num_classes=len(self.cl_dataset.get_cl_class_map(task_id)),
            optimizer=self.optimizer_t,
            lr_scheduler=self.lr_scheduler_t,
        )

        # train and validate the model
        self.trainer_t.fit(
            model=self.model,
            datamodule=self.cl_dataset,
        )

        # unlearn
        if task_id in self.unlearning_requests:
            unlearning_task_ids = self.unlearning_requests[task_id]
            pylogger.info(
                "Starting unlearning process for tasks: %s...", unlearning_task_ids
            )
            self.cul_algorithm.unlearn()
            pylogger.info("Unlearning process finished.")
            # record the unlearned tasks: `unlearnable_task_ids()` and
            # `task_ids_just_no_longer_unlearnable()` skip IDs in this set, so the
            # tasks are no longer offered for unlearning in later tasks
            self.unlearned_task_ids.update(unlearning_task_ids)

        # NOTE(review): unlearned tasks are still kept in `processed_task_ids`; the
        # removal below is disabled — confirm this is intended.
        # for unlearning_task_id in self.cul_algorithm.unlearning_task_ids:
        #     self.processed_task_ids.remove(unlearning_task_id)

        self.cul_algorithm.setup_test_task_id()

        # evaluation after training and validation
        if task_id in self.eval_after_tasks:
            self.trainer_t.test(
                model=self.model,
                datamodule=self.cl_dataset,
            )

        self.processed_task_ids.append(task_id)

The main method to run the continual unlearning main experiment.