r"""
The submodule in `cul_algorithms` for AmnesiacHAT unlearning algorithm.
"""
  4
__all__ = ["AmnesiacHATUnlearn"]

import logging

import torch
from rich.progress import track

from clarena.cl_algorithms.amnesiac_hat import AmnesiacHAT
from clarena.cul_algorithms import AmnesiacCULAlgorithm

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)
 18
 19class AmnesiacHATUnlearn(AmnesiacCULAlgorithm):
 20    r"""The base class of the AmnesiacHAT unlearning algorithm."""
 21
 22    def __init__(
 23        self,
 24        model: AmnesiacHAT,
 25        if_backup_compensation: bool,
 26        compensate_order: str | None,
 27        if_replay_repairing: bool,
 28        repair_batch_size: int | None,
 29        repair_num_steps: int | None,
 30        repair_strategy: str | None,
 31    ) -> None:
 32        r"""Initialize the unlearning algorithm with the continual learning model.
 33
 34        **Args:**
 35        - **model** (`AmnesiacHAT`): the continual learning model (`CLAlgorithm` object which already contains the backbone and heads). It must be an `AmnesiacHAT` algorithm.
 36        - **if_backup_compensation** (`bool`): whether to perform compensation using the backup backbones after unlearning.
 37        - **compensate_order** (`str` | `None`): the order to compensate the affected tasks after unlearning (used when `if_backup_compensation` is 'True'), must be:
 38            - 'forward': from oldest to newest.
 39            - 'reverse': from newest to oldest.
 40        - **if_replay_repairing** (`bool`): whether to perform replay after unlearning.
 41        - **repair_batch_size** (`int` | `None`): the batch size used during the replay repairing after unlearning (used when `if_replay_repairing` is 'True').
 42        - **repair_num_steps** (`int` | `None`): the number of steps to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True').
 43        - **repair_strategy** (`str` | `None`): the strategy to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). must be:
 44            - 'joint': use joint replay data from all affected tasks for repairing.
 45            - 'sequential_finetuning': use replay data from each affected task one by one (from oldest to newest) for repairing, with only one epoch per task. This forms a mini continual learning process during repairing, where we use Finetuning (no additional operation) to learn each affected task sequentially.
 46            - 'sequential_adahat': use replay data from each affected task one by one (from oldest to newest) for repairing. This forms a mini continual learning process during repairing, where we use AdaHAT (no mask sparsity reg) to learn each affected task sequentially.
 47        """
 48        super().__init__(model=model)
 49
 50        self.if_backup_compensation: bool = if_backup_compensation
 51        r"""Whether to perform compensation using the backup backbones after unlearning."""
 52        if self.if_backup_compensation:
 53            self.compensate_order: str = compensate_order
 54            r"""The order to compensate the affected tasks after unlearning."""
 55
 56        self.if_replay_repairing: bool = if_replay_repairing
 57        r"""Whether to perform replay repairing after unlearning."""
 58        if self.if_replay_repairing:
 59            self.repair_batch_size: int = repair_batch_size
 60            r"""The batch size used during the replay repairing after unlearning."""
 61            self.repair_num_steps: int = repair_num_steps
 62            r"""The number of steps to perform replay repairing after unlearning."""
 63            self.repair_strategy: str = repair_strategy
 64            r"""The strategy to perform replay repairing after unlearning."""
 65
 66    def compensate_by_backup(self) -> None:
 67        r"""Compensate the model using the backup backbones after unlearning."""
 68
 69        unlearning_task_id = self.unlearning_task_ids[
 70            0
 71        ]  # only one unlearning task is supported for now
 72
 73        task_ids_to_compensate = self.model.affected_tasks_after_unlearning()
 74
 75        if len(task_ids_to_compensate) == 0:
 76            pylogger.info(
 77                "No tasks to compensate after unlearning. Skipping compensation phase."
 78            )
 79            return
 80
 81        if self.compensate_order == "reverse":
 82            task_ids_to_compensate.reverse()  # compensate in reverse order
 83
 84        pylogger.debug(
 85            "Affected tasks by unlearning task %s is %s, will be compensated in this order.",
 86            unlearning_task_id,
 87            task_ids_to_compensate,
 88        )
 89
 90        for task_id_to_compensate in task_ids_to_compensate:
 91
 92            # get the backup state dict
 93            backup_state_dict = self.model.backbone.backup_state_dicts[
 94                (unlearning_task_id, task_id_to_compensate)
 95            ]
 96
 97            # only compensate the intersected neurons between the unlearning task and the affected task
 98            compensate_mask = self.model.backbone.combine_masks(
 99                [
100                    self.model.backbone.masks[task_id_to_compensate],
101                    self.model.backbone.masks[unlearning_task_id],
102                ],
103                mode="intersection",
104            )
105
106            for layer_name in self.model.backbone.weighted_layer_names:
107                layer = self.model.backbone.get_layer_by_name(layer_name)
108
109                # construct parameter-wise mask for the layer
110                weight_mask, bias_mask = (
111                    self.model.backbone.get_layer_measure_parameter_wise(
112                        neuron_wise_measure=compensate_mask,
113                        layer_name=layer_name,
114                        aggregation_mode="min",
115                    )
116                )
117
118                # compensate the parameters using the backup state dict
119                target_device = layer.weight.device
120                target_dtype = layer.weight.dtype
121                if weight_mask.device != target_device:
122                    weight_mask = weight_mask.to(device=target_device)
123                backup_weight = backup_state_dict[
124                    layer_name.replace("/", ".") + ".weight"
125                ].to(device=target_device, dtype=target_dtype)
126                layer.weight.data = torch.where(
127                    weight_mask.bool(),
128                    backup_weight,
129                    layer.weight.data,
130                )
131                if layer.bias is not None:
132                    if bias_mask.device != target_device:
133                        bias_mask = bias_mask.to(device=target_device)
134                    backup_bias = backup_state_dict[
135                        layer_name.replace("/", ".") + ".bias"
136                    ].to(device=target_device, dtype=layer.bias.dtype)
137                    layer.bias.data = torch.where(
138                        bias_mask.bool(),
139                        backup_bias,
140                        layer.bias.data,
141                    )
142
143            pylogger.debug(
144                "Compensated affected task %s using backup from unlearning task %s.",
145                task_id_to_compensate,
146                unlearning_task_id,
147            )
148
149    def replay_repairing(self) -> None:
150        r"""Repairing the model with replay after unlearning."""
151
152        task_ids_to_repair = self.model.affected_tasks_after_unlearning()
153
154        if len(task_ids_to_repair) == 0:
155            pylogger.info(
156                "No tasks to repair after unlearning. Skipping repairing phase."
157            )
158            return
159        else:
160            pylogger.info(
161                "Starting replay repairing tasks %s, after unlearning: %s. Repair strategy: %s.",
162                task_ids_to_repair,
163                self.unlearning_task_ids,
164                self.repair_strategy,
165            )
166
167        # align model device with replay buffer if needed (trainer may move model to CPU after fit)
168        buffer_device = (
169            self.model.memory_buffer.examples.device
170            if self.model.memory_buffer.examples.numel() > 0
171            else next(self.model.parameters()).device
172        )
173        if next(self.model.parameters()).device != buffer_device:
174            self.model.to(buffer_device)
175        model_device = next(self.model.parameters()).device
176
177        def _move_optimizer_state_to_device(optimizer, device: torch.device) -> None:
178            if isinstance(optimizer, (list, tuple)):
179                for opt in optimizer:
180                    _move_optimizer_state_to_device(opt, device)
181                return
182            opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
183            for state in opt.state.values():
184                for key, value in state.items():
185                    if torch.is_tensor(value) and value.device != device:
186                        state[key] = value.to(device)
187
188        # build the unlearning mask that aggregates all unlearning tasks. This mask is used to clip the gradients during replay repairing to prevent changing unaffected parameters
189        union_unlearning_mask = self.model.backbone.combine_masks(
190            [
191                self.model.backbone.masks[unlearning_task_id]
192                for unlearning_task_id in self.unlearning_task_ids
193            ],
194            mode="union",
195        )
196        union_unlearning_mask = {
197            layer_name: mask_tensor.to(model_device)
198            if mask_tensor.device != model_device
199            else mask_tensor
200            for layer_name, mask_tensor in union_unlearning_mask.items()
201        }
202
203        if (
204            self.repair_strategy == "sequential_finetuning"
205            or self.repair_strategy == "sequential_adahat"
206        ):
207
208            summative_mask_for_previous_tasks_in_replay_repairing = {
209                layer_name: torch.zeros(
210                    self.model.backbone.get_layer_by_name(layer_name).weight.shape[0],
211                    device=model_device,
212                )
213                for layer_name in self.model.backbone.weighted_layer_names
214            }
215
216            opt = self.model.optimizers()
217            _move_optimizer_state_to_device(opt, model_device)
218
219            total_tasks = len(task_ids_to_repair)
220            for task_index, task_id_to_repair in enumerate(task_ids_to_repair, start=1):
221                affected_task_mask = {
222                    layer_name: layer_mask.to(model_device)
223                    if layer_mask.device != model_device
224                    else layer_mask
225                    for layer_name, layer_mask in self.model.backbone.masks[
226                        task_id_to_repair
227                    ].items()
228                }
229                unlearning_and_affected_mask = self.model.backbone.combine_masks(
230                    [union_unlearning_mask, affected_task_mask],
231                    mode="intersection",
232                )
233
234                for s in track(
235                    range(self.repair_num_steps),
236                    description=f"Replay repairing task {task_id_to_repair} ({task_index}/{total_tasks})",
237                    transient=True,
238                ):
239                    # get replay data for repairing from memory buffer
240                    x_replay, labels_replay, logits_replay, _ = (
241                        self.model.memory_buffer.get_data(
242                            self.repair_batch_size,
243                            included_tasks=[task_id_to_repair],
244                        )
245                    )
246                    if x_replay.device != model_device:
247                        x_replay = x_replay.to(model_device)
248                    if labels_replay.device != model_device:
249                        labels_replay = labels_replay.to(model_device)
250                    if logits_replay.device != model_device:
251                        logits_replay = logits_replay.to(model_device)
252
253                    # zero the gradients before forward pass in manual optimization mode
254                    opt.zero_grad()
255
256                    student_feature_replay = self.model.backbone(
257                        x_replay,
258                        stage="test",
259                        test_task_id=task_id_to_repair,
260                    )[0]
261
262                    student_logits_replay = self.model.heads(
263                        student_feature_replay, task_id=task_id_to_repair
264                    )
265
266                    with torch.no_grad():  # stop updating the previous heads
267                        teacher_logits_replay = logits_replay
268
269                    loss = self.model.distillation_reg(
270                        student_logits=student_logits_replay,
271                        teacher_logits=teacher_logits_replay,
272                    )
273                    loss += self.model.replay_ce_factor * self.model.criterion(
274                        student_logits_replay, labels_replay.long()
275                    )
276
277                    self.model.manual_backward(loss)  # calculate the gradients
278
279                    # Clip gradients outside the intersection between
280                    # unlearning-affected units and the current affected task.
281
282                    self.model.clip_grad_by_mask(
283                        mask=unlearning_and_affected_mask, aggregation_mode="min"
284                    )
285
286                    if self.repair_strategy == "sequential_adahat":
287                        self.model.clip_grad_by_adjustment_in_replay_repairing(
288                            summative_mask_for_previous_tasks_in_replay_repairing=summative_mask_for_previous_tasks_in_replay_repairing
289                        )
290
291                    # update parameters with the modified gradients
292                    opt.step()
293
294                summative_mask_for_previous_tasks_in_replay_repairing = {
295                    layer_name: summative_mask_for_previous_tasks_in_replay_repairing[
296                        layer_name
297                    ]
298                    + (
299                        self.model.backbone.masks[task_id_to_repair][layer_name].to(
300                            model_device
301                        )
302                        if self.model.backbone.masks[task_id_to_repair][
303                            layer_name
304                        ].device
305                        != model_device
306                        else self.model.backbone.masks[task_id_to_repair][layer_name]
307                    )
308                    for layer_name in self.model.backbone.weighted_layer_names
309                }
310
311        elif self.repair_strategy == "joint":
312
313            opt = self.model.optimizers()
314            _move_optimizer_state_to_device(opt, model_device)
315
316            for s in track(
317                range(self.repair_num_steps),
318                description="Replay repairing (joint)",
319                transient=True,
320            ):
321
322                # get replay data for repairing from memory buffer
323                x_replay, labels_replay, logits_replay, task_labels_replay = (
324                    self.model.memory_buffer.get_data(
325                        self.repair_batch_size,
326                        included_tasks=task_ids_to_repair,
327                    )
328                )
329                if x_replay.device != model_device:
330                    x_replay = x_replay.to(model_device)
331                if labels_replay.device != model_device:
332                    labels_replay = labels_replay.to(model_device)
333                if logits_replay.device != model_device:
334                    logits_replay = logits_replay.to(model_device)
335                if task_labels_replay.device != model_device:
336                    task_labels_replay = task_labels_replay.to(model_device)
337
338                # zero the gradients before forward pass in manual optimization mode
339                opt.zero_grad()
340
341                student_feature_replay = torch.cat(
342                    [
343                        self.model.backbone(
344                            x_replay[i].unsqueeze(0),
345                            stage="test",
346                            test_task_id=tid.item(),
347                        )[0]
348                        for i, tid in enumerate(task_labels_replay)
349                    ]
350                )
351
352                student_logits_replay = torch.cat(
353                    [
354                        self.model.heads(
355                            student_feature_replay[i].unsqueeze(0), task_id=tid
356                        )
357                        for i, tid in enumerate(task_labels_replay)
358                    ]
359                )
360
361                with torch.no_grad():  # stop updating the previous heads
362                    teacher_logits_replay = logits_replay
363
364                loss = self.model.distillation_reg(
365                    student_logits=student_logits_replay,
366                    teacher_logits=teacher_logits_replay,
367                )
368                loss += self.model.replay_ce_factor * self.model.criterion(
369                    student_logits_replay, labels_replay.long()
370                )
371
372                self.model.manual_backward(loss)  # calculate the gradients
373
374                batch_task_ids = [int(task_id) for task_id in torch.unique(task_labels_replay)]
375                batch_affected_task_masks = []
376                for task_id in batch_task_ids:
377                    affected_task_mask = {
378                        layer_name: layer_mask.to(model_device)
379                        if layer_mask.device != model_device
380                        else layer_mask
381                        for layer_name, layer_mask in self.model.backbone.masks[
382                            task_id
383                        ].items()
384                    }
385                    batch_affected_task_masks.append(affected_task_mask)
386                union_batch_affected_mask = self.model.backbone.combine_masks(
387                    batch_affected_task_masks,
388                    mode="union",  
389                )
390                unlearning_and_affected_mask = self.model.backbone.combine_masks(
391                    [union_unlearning_mask, union_batch_affected_mask],
392                    mode="intersection",
393                )
394
395                self.model.clip_grad_by_mask(
396                    mask=unlearning_and_affected_mask, aggregation_mode="min"
397                )
398
399                # update parameters with the modified gradients
400                opt.step()
401
402    def unlearn(self) -> None:
403        r"""Unlearn the requested unlearning tasks (`self.unlearning_task_ids`) in the current task `self.task_id`."""
404
405        # delete the corresponding parameter update records
406        self.delete_update()
407
408        for unlearning_task_id in self.unlearning_task_ids:
409
410            # delete the data of the unlearning task from the memory buffer
411            self.model.memory_buffer.delete_task(unlearning_task_id)
412
413        if self.if_backup_compensation:
414            self.compensate_by_backup()
415
416        if self.if_replay_repairing:
417            self.replay_repairing()
418
419        # do not delete the masks and other related info of the unlearning tasks, as they may be needed in testing
class AmnesiacHATUnlearn(clarena.cul_algorithms.base.AmnesiacCULAlgorithm):
 20class AmnesiacHATUnlearn(AmnesiacCULAlgorithm):
 21    r"""The base class of the AmnesiacHAT unlearning algorithm."""
 22
 23    def __init__(
 24        self,
 25        model: AmnesiacHAT,
 26        if_backup_compensation: bool,
 27        compensate_order: str | None,
 28        if_replay_repairing: bool,
 29        repair_batch_size: int | None,
 30        repair_num_steps: int | None,
 31        repair_strategy: str | None,
 32    ) -> None:
 33        r"""Initialize the unlearning algorithm with the continual learning model.
 34
 35        **Args:**
 36        - **model** (`AmnesiacHAT`): the continual learning model (`CLAlgorithm` object which already contains the backbone and heads). It must be an `AmnesiacHAT` algorithm.
 37        - **if_backup_compensation** (`bool`): whether to perform compensation using the backup backbones after unlearning.
 38        - **compensate_order** (`str` | `None`): the order to compensate the affected tasks after unlearning (used when `if_backup_compensation` is 'True'), must be:
 39            - 'forward': from oldest to newest.
 40            - 'reverse': from newest to oldest.
 41        - **if_replay_repairing** (`bool`): whether to perform replay after unlearning.
 42        - **repair_batch_size** (`int` | `None`): the batch size used during the replay repairing after unlearning (used when `if_replay_repairing` is 'True').
 43        - **repair_num_steps** (`int` | `None`): the number of steps to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True').
 44        - **repair_strategy** (`str` | `None`): the strategy to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). must be:
 45            - 'joint': use joint replay data from all affected tasks for repairing.
 46            - 'sequential_finetuning': use replay data from each affected task one by one (from oldest to newest) for repairing, with only one epoch per task. This forms a mini continual learning process during repairing, where we use Finetuning (no additional operation) to learn each affected task sequentially.
 47            - 'sequential_adahat': use replay data from each affected task one by one (from oldest to newest) for repairing. This forms a mini continual learning process during repairing, where we use AdaHAT (no mask sparsity reg) to learn each affected task sequentially.
 48        """
 49        super().__init__(model=model)
 50
 51        self.if_backup_compensation: bool = if_backup_compensation
 52        r"""Whether to perform compensation using the backup backbones after unlearning."""
 53        if self.if_backup_compensation:
 54            self.compensate_order: str = compensate_order
 55            r"""The order to compensate the affected tasks after unlearning."""
 56
 57        self.if_replay_repairing: bool = if_replay_repairing
 58        r"""Whether to perform replay repairing after unlearning."""
 59        if self.if_replay_repairing:
 60            self.repair_batch_size: int = repair_batch_size
 61            r"""The batch size used during the replay repairing after unlearning."""
 62            self.repair_num_steps: int = repair_num_steps
 63            r"""The number of steps to perform replay repairing after unlearning."""
 64            self.repair_strategy: str = repair_strategy
 65            r"""The strategy to perform replay repairing after unlearning."""
 66
 67    def compensate_by_backup(self) -> None:
 68        r"""Compensate the model using the backup backbones after unlearning."""
 69
 70        unlearning_task_id = self.unlearning_task_ids[
 71            0
 72        ]  # only one unlearning task is supported for now
 73
 74        task_ids_to_compensate = self.model.affected_tasks_after_unlearning()
 75
 76        if len(task_ids_to_compensate) == 0:
 77            pylogger.info(
 78                "No tasks to compensate after unlearning. Skipping compensation phase."
 79            )
 80            return
 81
 82        if self.compensate_order == "reverse":
 83            task_ids_to_compensate.reverse()  # compensate in reverse order
 84
 85        pylogger.debug(
 86            "Affected tasks by unlearning task %s is %s, will be compensated in this order.",
 87            unlearning_task_id,
 88            task_ids_to_compensate,
 89        )
 90
 91        for task_id_to_compensate in task_ids_to_compensate:
 92
 93            # get the backup state dict
 94            backup_state_dict = self.model.backbone.backup_state_dicts[
 95                (unlearning_task_id, task_id_to_compensate)
 96            ]
 97
 98            # only compensate the intersected neurons between the unlearning task and the affected task
 99            compensate_mask = self.model.backbone.combine_masks(
100                [
101                    self.model.backbone.masks[task_id_to_compensate],
102                    self.model.backbone.masks[unlearning_task_id],
103                ],
104                mode="intersection",
105            )
106
107            for layer_name in self.model.backbone.weighted_layer_names:
108                layer = self.model.backbone.get_layer_by_name(layer_name)
109
110                # construct parameter-wise mask for the layer
111                weight_mask, bias_mask = (
112                    self.model.backbone.get_layer_measure_parameter_wise(
113                        neuron_wise_measure=compensate_mask,
114                        layer_name=layer_name,
115                        aggregation_mode="min",
116                    )
117                )
118
119                # compensate the parameters using the backup state dict
120                target_device = layer.weight.device
121                target_dtype = layer.weight.dtype
122                if weight_mask.device != target_device:
123                    weight_mask = weight_mask.to(device=target_device)
124                backup_weight = backup_state_dict[
125                    layer_name.replace("/", ".") + ".weight"
126                ].to(device=target_device, dtype=target_dtype)
127                layer.weight.data = torch.where(
128                    weight_mask.bool(),
129                    backup_weight,
130                    layer.weight.data,
131                )
132                if layer.bias is not None:
133                    if bias_mask.device != target_device:
134                        bias_mask = bias_mask.to(device=target_device)
135                    backup_bias = backup_state_dict[
136                        layer_name.replace("/", ".") + ".bias"
137                    ].to(device=target_device, dtype=layer.bias.dtype)
138                    layer.bias.data = torch.where(
139                        bias_mask.bool(),
140                        backup_bias,
141                        layer.bias.data,
142                    )
143
144            pylogger.debug(
145                "Compensated affected task %s using backup from unlearning task %s.",
146                task_id_to_compensate,
147                unlearning_task_id,
148            )
149
150    def replay_repairing(self) -> None:
151        r"""Repairing the model with replay after unlearning."""
152
153        task_ids_to_repair = self.model.affected_tasks_after_unlearning()
154
155        if len(task_ids_to_repair) == 0:
156            pylogger.info(
157                "No tasks to repair after unlearning. Skipping repairing phase."
158            )
159            return
160        else:
161            pylogger.info(
162                "Starting replay repairing tasks %s, after unlearning: %s. Repair strategy: %s.",
163                task_ids_to_repair,
164                self.unlearning_task_ids,
165                self.repair_strategy,
166            )
167
168        # align model device with replay buffer if needed (trainer may move model to CPU after fit)
169        buffer_device = (
170            self.model.memory_buffer.examples.device
171            if self.model.memory_buffer.examples.numel() > 0
172            else next(self.model.parameters()).device
173        )
174        if next(self.model.parameters()).device != buffer_device:
175            self.model.to(buffer_device)
176        model_device = next(self.model.parameters()).device
177
178        def _move_optimizer_state_to_device(optimizer, device: torch.device) -> None:
179            if isinstance(optimizer, (list, tuple)):
180                for opt in optimizer:
181                    _move_optimizer_state_to_device(opt, device)
182                return
183            opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
184            for state in opt.state.values():
185                for key, value in state.items():
186                    if torch.is_tensor(value) and value.device != device:
187                        state[key] = value.to(device)
188
189        # build the unlearning mask that aggregates all unlearning tasks. This mask is used to clip the gradients during replay repairing to prevent changing unaffected parameters
190        union_unlearning_mask = self.model.backbone.combine_masks(
191            [
192                self.model.backbone.masks[unlearning_task_id]
193                for unlearning_task_id in self.unlearning_task_ids
194            ],
195            mode="union",
196        )
197        union_unlearning_mask = {
198            layer_name: mask_tensor.to(model_device)
199            if mask_tensor.device != model_device
200            else mask_tensor
201            for layer_name, mask_tensor in union_unlearning_mask.items()
202        }
203
204        if (
205            self.repair_strategy == "sequential_finetuning"
206            or self.repair_strategy == "sequential_adahat"
207        ):
208
209            summative_mask_for_previous_tasks_in_replay_repairing = {
210                layer_name: torch.zeros(
211                    self.model.backbone.get_layer_by_name(layer_name).weight.shape[0],
212                    device=model_device,
213                )
214                for layer_name in self.model.backbone.weighted_layer_names
215            }
216
217            opt = self.model.optimizers()
218            _move_optimizer_state_to_device(opt, model_device)
219
220            total_tasks = len(task_ids_to_repair)
221            for task_index, task_id_to_repair in enumerate(task_ids_to_repair, start=1):
222                affected_task_mask = {
223                    layer_name: layer_mask.to(model_device)
224                    if layer_mask.device != model_device
225                    else layer_mask
226                    for layer_name, layer_mask in self.model.backbone.masks[
227                        task_id_to_repair
228                    ].items()
229                }
230                unlearning_and_affected_mask = self.model.backbone.combine_masks(
231                    [union_unlearning_mask, affected_task_mask],
232                    mode="intersection",
233                )
234
235                for s in track(
236                    range(self.repair_num_steps),
237                    description=f"Replay repairing task {task_id_to_repair} ({task_index}/{total_tasks})",
238                    transient=True,
239                ):
240                    # get replay data for repairing from memory buffer
241                    x_replay, labels_replay, logits_replay, _ = (
242                        self.model.memory_buffer.get_data(
243                            self.repair_batch_size,
244                            included_tasks=[task_id_to_repair],
245                        )
246                    )
247                    if x_replay.device != model_device:
248                        x_replay = x_replay.to(model_device)
249                    if labels_replay.device != model_device:
250                        labels_replay = labels_replay.to(model_device)
251                    if logits_replay.device != model_device:
252                        logits_replay = logits_replay.to(model_device)
253
254                    # zero the gradients before forward pass in manual optimization mode
255                    opt.zero_grad()
256
257                    student_feature_replay = self.model.backbone(
258                        x_replay,
259                        stage="test",
260                        test_task_id=task_id_to_repair,
261                    )[0]
262
263                    student_logits_replay = self.model.heads(
264                        student_feature_replay, task_id=task_id_to_repair
265                    )
266
267                    with torch.no_grad():  # stop updating the previous heads
268                        teacher_logits_replay = logits_replay
269
270                    loss = self.model.distillation_reg(
271                        student_logits=student_logits_replay,
272                        teacher_logits=teacher_logits_replay,
273                    )
274                    loss += self.model.replay_ce_factor * self.model.criterion(
275                        student_logits_replay, labels_replay.long()
276                    )
277
278                    self.model.manual_backward(loss)  # calculate the gradients
279
280                    # Clip gradients outside the intersection between
281                    # unlearning-affected units and the current affected task.
282
283                    self.model.clip_grad_by_mask(
284                        mask=unlearning_and_affected_mask, aggregation_mode="min"
285                    )
286
287                    if self.repair_strategy == "sequential_adahat":
288                        self.model.clip_grad_by_adjustment_in_replay_repairing(
289                            summative_mask_for_previous_tasks_in_replay_repairing=summative_mask_for_previous_tasks_in_replay_repairing
290                        )
291
292                    # update parameters with the modified gradients
293                    opt.step()
294
295                summative_mask_for_previous_tasks_in_replay_repairing = {
296                    layer_name: summative_mask_for_previous_tasks_in_replay_repairing[
297                        layer_name
298                    ]
299                    + (
300                        self.model.backbone.masks[task_id_to_repair][layer_name].to(
301                            model_device
302                        )
303                        if self.model.backbone.masks[task_id_to_repair][
304                            layer_name
305                        ].device
306                        != model_device
307                        else self.model.backbone.masks[task_id_to_repair][layer_name]
308                    )
309                    for layer_name in self.model.backbone.weighted_layer_names
310                }
311
312        elif self.repair_strategy == "joint":
313
314            opt = self.model.optimizers()
315            _move_optimizer_state_to_device(opt, model_device)
316
317            for s in track(
318                range(self.repair_num_steps),
319                description="Replay repairing (joint)",
320                transient=True,
321            ):
322
323                # get replay data for repairing from memory buffer
324                x_replay, labels_replay, logits_replay, task_labels_replay = (
325                    self.model.memory_buffer.get_data(
326                        self.repair_batch_size,
327                        included_tasks=task_ids_to_repair,
328                    )
329                )
330                if x_replay.device != model_device:
331                    x_replay = x_replay.to(model_device)
332                if labels_replay.device != model_device:
333                    labels_replay = labels_replay.to(model_device)
334                if logits_replay.device != model_device:
335                    logits_replay = logits_replay.to(model_device)
336                if task_labels_replay.device != model_device:
337                    task_labels_replay = task_labels_replay.to(model_device)
338
339                # zero the gradients before forward pass in manual optimization mode
340                opt.zero_grad()
341
342                student_feature_replay = torch.cat(
343                    [
344                        self.model.backbone(
345                            x_replay[i].unsqueeze(0),
346                            stage="test",
347                            test_task_id=tid.item(),
348                        )[0]
349                        for i, tid in enumerate(task_labels_replay)
350                    ]
351                )
352
353                student_logits_replay = torch.cat(
354                    [
355                        self.model.heads(
356                            student_feature_replay[i].unsqueeze(0), task_id=tid
357                        )
358                        for i, tid in enumerate(task_labels_replay)
359                    ]
360                )
361
362                with torch.no_grad():  # stop updating the previous heads
363                    teacher_logits_replay = logits_replay
364
365                loss = self.model.distillation_reg(
366                    student_logits=student_logits_replay,
367                    teacher_logits=teacher_logits_replay,
368                )
369                loss += self.model.replay_ce_factor * self.model.criterion(
370                    student_logits_replay, labels_replay.long()
371                )
372
373                self.model.manual_backward(loss)  # calculate the gradients
374
375                batch_task_ids = [int(task_id) for task_id in torch.unique(task_labels_replay)]
376                batch_affected_task_masks = []
377                for task_id in batch_task_ids:
378                    affected_task_mask = {
379                        layer_name: layer_mask.to(model_device)
380                        if layer_mask.device != model_device
381                        else layer_mask
382                        for layer_name, layer_mask in self.model.backbone.masks[
383                            task_id
384                        ].items()
385                    }
386                    batch_affected_task_masks.append(affected_task_mask)
387                union_batch_affected_mask = self.model.backbone.combine_masks(
388                    batch_affected_task_masks,
389                    mode="union",  
390                )
391                unlearning_and_affected_mask = self.model.backbone.combine_masks(
392                    [union_unlearning_mask, union_batch_affected_mask],
393                    mode="intersection",
394                )
395
396                self.model.clip_grad_by_mask(
397                    mask=unlearning_and_affected_mask, aggregation_mode="min"
398                )
399
400                # update parameters with the modified gradients
401                opt.step()
402
403    def unlearn(self) -> None:
404        r"""Unlearn the requested unlearning tasks (`self.unlearning_task_ids`) in the current task `self.task_id`."""
405
406        # delete the corresponding parameter update records
407        self.delete_update()
408
409        for unlearning_task_id in self.unlearning_task_ids:
410
411            # delete the data of the unlearning task from the memory buffer
412            self.model.memory_buffer.delete_task(unlearning_task_id)
413
414        if self.if_backup_compensation:
415            self.compensate_by_backup()
416
417        if self.if_replay_repairing:
418            self.replay_repairing()
419
420        # do not delete the masks and other related info of the unlearning tasks, as they may be needed in testing

The base class of the AmnesiacHAT unlearning algorithm.

AmnesiacHATUnlearn( model: clarena.cl_algorithms.amnesiac_hat.AmnesiacHAT, if_backup_compensation: bool, compensate_order: str | None, if_replay_repairing: bool, repair_batch_size: int | None, repair_num_steps: int | None, repair_strategy: str | None)
23    def __init__(
24        self,
25        model: AmnesiacHAT,
26        if_backup_compensation: bool,
27        compensate_order: str | None,
28        if_replay_repairing: bool,
29        repair_batch_size: int | None,
30        repair_num_steps: int | None,
31        repair_strategy: str | None,
32    ) -> None:
33        r"""Initialize the unlearning algorithm with the continual learning model.
34
35        **Args:**
36        - **model** (`AmnesiacHAT`): the continual learning model (`CLAlgorithm` object which already contains the backbone and heads). It must be an `AmnesiacHAT` algorithm.
37        - **if_backup_compensation** (`bool`): whether to perform compensation using the backup backbones after unlearning.
38        - **compensate_order** (`str` | `None`): the order to compensate the affected tasks after unlearning (used when `if_backup_compensation` is 'True'), must be:
39            - 'forward': from oldest to newest.
40            - 'reverse': from newest to oldest.
41        - **if_replay_repairing** (`bool`): whether to perform replay after unlearning.
42        - **repair_batch_size** (`int` | `None`): the batch size used during the replay repairing after unlearning (used when `if_replay_repairing` is 'True').
43        - **repair_num_steps** (`int` | `None`): the number of steps to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True').
44        - **repair_strategy** (`str` | `None`): the strategy to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). must be:
45            - 'joint': use joint replay data from all affected tasks for repairing.
46            - 'sequential_finetuning': use replay data from each affected task one by one (from oldest to newest) for repairing, with only one epoch per task. This forms a mini continual learning process during repairing, where we use Finetuning (no additional operation) to learn each affected task sequentially.
47            - 'sequential_adahat': use replay data from each affected task one by one (from oldest to newest) for repairing. This forms a mini continual learning process during repairing, where we use AdaHAT (no mask sparsity reg) to learn each affected task sequentially.
48        """
49        super().__init__(model=model)
50
51        self.if_backup_compensation: bool = if_backup_compensation
52        r"""Whether to perform compensation using the backup backbones after unlearning."""
53        if self.if_backup_compensation:
54            self.compensate_order: str = compensate_order
55            r"""The order to compensate the affected tasks after unlearning."""
56
57        self.if_replay_repairing: bool = if_replay_repairing
58        r"""Whether to perform replay repairing after unlearning."""
59        if self.if_replay_repairing:
60            self.repair_batch_size: int = repair_batch_size
61            r"""The batch size used during the replay repairing after unlearning."""
62            self.repair_num_steps: int = repair_num_steps
63            r"""The number of steps to perform replay repairing after unlearning."""
64            self.repair_strategy: str = repair_strategy
65            r"""The strategy to perform replay repairing after unlearning."""

Initialize the unlearning algorithm with the continual learning model.

Args:

  • model (AmnesiacHAT): the continual learning model (CLAlgorithm object which already contains the backbone and heads). It must be an AmnesiacHAT algorithm.
  • if_backup_compensation (bool): whether to perform compensation using the backup backbones after unlearning.
  • compensate_order (str | None): the order to compensate the affected tasks after unlearning (used when if_backup_compensation is 'True'), must be:
    • 'forward': from oldest to newest.
    • 'reverse': from newest to oldest.
  • if_replay_repairing (bool): whether to perform replay repairing after unlearning.
  • repair_batch_size (int | None): the batch size used during the replay repairing after unlearning (used when if_replay_repairing is 'True').
  • repair_num_steps (int | None): the number of steps to perform replay repairing after unlearning (used when if_replay_repairing is 'True').
  • repair_strategy (str | None): the strategy to perform replay repairing after unlearning (used when if_replay_repairing is 'True'). Must be:
    • 'joint': use joint replay data from all affected tasks for repairing.
    • 'sequential_finetuning': use replay data from each affected task one by one (from oldest to newest) for repairing, with only one epoch per task. This forms a mini continual learning process during repairing, where we use Finetuning (no additional operation) to learn each affected task sequentially.
    • 'sequential_adahat': use replay data from each affected task one by one (from oldest to newest) for repairing. This forms a mini continual learning process during repairing, where we use AdaHAT (no mask sparsity reg) to learn each affected task sequentially.
if_backup_compensation: bool

Whether to perform compensation using the backup backbones after unlearning.

if_replay_repairing: bool

Whether to perform replay repairing after unlearning.

def compensate_by_backup(self) -> None:
 67    def compensate_by_backup(self) -> None:
 68        r"""Compensate the model using the backup backbones after unlearning."""
 69
 70        unlearning_task_id = self.unlearning_task_ids[
 71            0
 72        ]  # only one unlearning task is supported for now
 73
 74        task_ids_to_compensate = self.model.affected_tasks_after_unlearning()
 75
 76        if len(task_ids_to_compensate) == 0:
 77            pylogger.info(
 78                "No tasks to compensate after unlearning. Skipping compensation phase."
 79            )
 80            return
 81
 82        if self.compensate_order == "reverse":
 83            task_ids_to_compensate.reverse()  # compensate in reverse order
 84
 85        pylogger.debug(
 86            "Affected tasks by unlearning task %s is %s, will be compensated in this order.",
 87            unlearning_task_id,
 88            task_ids_to_compensate,
 89        )
 90
 91        for task_id_to_compensate in task_ids_to_compensate:
 92
 93            # get the backup state dict
 94            backup_state_dict = self.model.backbone.backup_state_dicts[
 95                (unlearning_task_id, task_id_to_compensate)
 96            ]
 97
 98            # only compensate the intersected neurons between the unlearning task and the affected task
 99            compensate_mask = self.model.backbone.combine_masks(
100                [
101                    self.model.backbone.masks[task_id_to_compensate],
102                    self.model.backbone.masks[unlearning_task_id],
103                ],
104                mode="intersection",
105            )
106
107            for layer_name in self.model.backbone.weighted_layer_names:
108                layer = self.model.backbone.get_layer_by_name(layer_name)
109
110                # construct parameter-wise mask for the layer
111                weight_mask, bias_mask = (
112                    self.model.backbone.get_layer_measure_parameter_wise(
113                        neuron_wise_measure=compensate_mask,
114                        layer_name=layer_name,
115                        aggregation_mode="min",
116                    )
117                )
118
119                # compensate the parameters using the backup state dict
120                target_device = layer.weight.device
121                target_dtype = layer.weight.dtype
122                if weight_mask.device != target_device:
123                    weight_mask = weight_mask.to(device=target_device)
124                backup_weight = backup_state_dict[
125                    layer_name.replace("/", ".") + ".weight"
126                ].to(device=target_device, dtype=target_dtype)
127                layer.weight.data = torch.where(
128                    weight_mask.bool(),
129                    backup_weight,
130                    layer.weight.data,
131                )
132                if layer.bias is not None:
133                    if bias_mask.device != target_device:
134                        bias_mask = bias_mask.to(device=target_device)
135                    backup_bias = backup_state_dict[
136                        layer_name.replace("/", ".") + ".bias"
137                    ].to(device=target_device, dtype=layer.bias.dtype)
138                    layer.bias.data = torch.where(
139                        bias_mask.bool(),
140                        backup_bias,
141                        layer.bias.data,
142                    )
143
144            pylogger.debug(
145                "Compensated affected task %s using backup from unlearning task %s.",
146                task_id_to_compensate,
147                unlearning_task_id,
148            )

Compensate the model using the backup backbones after unlearning.

def replay_repairing(self) -> None:
150    def replay_repairing(self) -> None:
151        r"""Repairing the model with replay after unlearning."""
152
153        task_ids_to_repair = self.model.affected_tasks_after_unlearning()
154
155        if len(task_ids_to_repair) == 0:
156            pylogger.info(
157                "No tasks to repair after unlearning. Skipping repairing phase."
158            )
159            return
160        else:
161            pylogger.info(
162                "Starting replay repairing tasks %s, after unlearning: %s. Repair strategy: %s.",
163                task_ids_to_repair,
164                self.unlearning_task_ids,
165                self.repair_strategy,
166            )
167
168        # align model device with replay buffer if needed (trainer may move model to CPU after fit)
169        buffer_device = (
170            self.model.memory_buffer.examples.device
171            if self.model.memory_buffer.examples.numel() > 0
172            else next(self.model.parameters()).device
173        )
174        if next(self.model.parameters()).device != buffer_device:
175            self.model.to(buffer_device)
176        model_device = next(self.model.parameters()).device
177
178        def _move_optimizer_state_to_device(optimizer, device: torch.device) -> None:
179            if isinstance(optimizer, (list, tuple)):
180                for opt in optimizer:
181                    _move_optimizer_state_to_device(opt, device)
182                return
183            opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
184            for state in opt.state.values():
185                for key, value in state.items():
186                    if torch.is_tensor(value) and value.device != device:
187                        state[key] = value.to(device)
188
189        # build the unlearning mask that aggregates all unlearning tasks. This mask is used to clip the gradients during replay repairing to prevent changing unaffected parameters
190        union_unlearning_mask = self.model.backbone.combine_masks(
191            [
192                self.model.backbone.masks[unlearning_task_id]
193                for unlearning_task_id in self.unlearning_task_ids
194            ],
195            mode="union",
196        )
197        union_unlearning_mask = {
198            layer_name: mask_tensor.to(model_device)
199            if mask_tensor.device != model_device
200            else mask_tensor
201            for layer_name, mask_tensor in union_unlearning_mask.items()
202        }
203
204        if (
205            self.repair_strategy == "sequential_finetuning"
206            or self.repair_strategy == "sequential_adahat"
207        ):
208
209            summative_mask_for_previous_tasks_in_replay_repairing = {
210                layer_name: torch.zeros(
211                    self.model.backbone.get_layer_by_name(layer_name).weight.shape[0],
212                    device=model_device,
213                )
214                for layer_name in self.model.backbone.weighted_layer_names
215            }
216
217            opt = self.model.optimizers()
218            _move_optimizer_state_to_device(opt, model_device)
219
220            total_tasks = len(task_ids_to_repair)
221            for task_index, task_id_to_repair in enumerate(task_ids_to_repair, start=1):
222                affected_task_mask = {
223                    layer_name: layer_mask.to(model_device)
224                    if layer_mask.device != model_device
225                    else layer_mask
226                    for layer_name, layer_mask in self.model.backbone.masks[
227                        task_id_to_repair
228                    ].items()
229                }
230                unlearning_and_affected_mask = self.model.backbone.combine_masks(
231                    [union_unlearning_mask, affected_task_mask],
232                    mode="intersection",
233                )
234
235                for s in track(
236                    range(self.repair_num_steps),
237                    description=f"Replay repairing task {task_id_to_repair} ({task_index}/{total_tasks})",
238                    transient=True,
239                ):
240                    # get replay data for repairing from memory buffer
241                    x_replay, labels_replay, logits_replay, _ = (
242                        self.model.memory_buffer.get_data(
243                            self.repair_batch_size,
244                            included_tasks=[task_id_to_repair],
245                        )
246                    )
247                    if x_replay.device != model_device:
248                        x_replay = x_replay.to(model_device)
249                    if labels_replay.device != model_device:
250                        labels_replay = labels_replay.to(model_device)
251                    if logits_replay.device != model_device:
252                        logits_replay = logits_replay.to(model_device)
253
254                    # zero the gradients before forward pass in manual optimization mode
255                    opt.zero_grad()
256
257                    student_feature_replay = self.model.backbone(
258                        x_replay,
259                        stage="test",
260                        test_task_id=task_id_to_repair,
261                    )[0]
262
263                    student_logits_replay = self.model.heads(
264                        student_feature_replay, task_id=task_id_to_repair
265                    )
266
267                    with torch.no_grad():  # stop updating the previous heads
268                        teacher_logits_replay = logits_replay
269
270                    loss = self.model.distillation_reg(
271                        student_logits=student_logits_replay,
272                        teacher_logits=teacher_logits_replay,
273                    )
274                    loss += self.model.replay_ce_factor * self.model.criterion(
275                        student_logits_replay, labels_replay.long()
276                    )
277
278                    self.model.manual_backward(loss)  # calculate the gradients
279
280                    # Clip gradients outside the intersection between
281                    # unlearning-affected units and the current affected task.
282
283                    self.model.clip_grad_by_mask(
284                        mask=unlearning_and_affected_mask, aggregation_mode="min"
285                    )
286
287                    if self.repair_strategy == "sequential_adahat":
288                        self.model.clip_grad_by_adjustment_in_replay_repairing(
289                            summative_mask_for_previous_tasks_in_replay_repairing=summative_mask_for_previous_tasks_in_replay_repairing
290                        )
291
292                    # update parameters with the modified gradients
293                    opt.step()
294
295                summative_mask_for_previous_tasks_in_replay_repairing = {
296                    layer_name: summative_mask_for_previous_tasks_in_replay_repairing[
297                        layer_name
298                    ]
299                    + (
300                        self.model.backbone.masks[task_id_to_repair][layer_name].to(
301                            model_device
302                        )
303                        if self.model.backbone.masks[task_id_to_repair][
304                            layer_name
305                        ].device
306                        != model_device
307                        else self.model.backbone.masks[task_id_to_repair][layer_name]
308                    )
309                    for layer_name in self.model.backbone.weighted_layer_names
310                }
311
312        elif self.repair_strategy == "joint":
313
314            opt = self.model.optimizers()
315            _move_optimizer_state_to_device(opt, model_device)
316
317            for s in track(
318                range(self.repair_num_steps),
319                description="Replay repairing (joint)",
320                transient=True,
321            ):
322
323                # get replay data for repairing from memory buffer
324                x_replay, labels_replay, logits_replay, task_labels_replay = (
325                    self.model.memory_buffer.get_data(
326                        self.repair_batch_size,
327                        included_tasks=task_ids_to_repair,
328                    )
329                )
330                if x_replay.device != model_device:
331                    x_replay = x_replay.to(model_device)
332                if labels_replay.device != model_device:
333                    labels_replay = labels_replay.to(model_device)
334                if logits_replay.device != model_device:
335                    logits_replay = logits_replay.to(model_device)
336                if task_labels_replay.device != model_device:
337                    task_labels_replay = task_labels_replay.to(model_device)
338
339                # zero the gradients before forward pass in manual optimization mode
340                opt.zero_grad()
341
342                student_feature_replay = torch.cat(
343                    [
344                        self.model.backbone(
345                            x_replay[i].unsqueeze(0),
346                            stage="test",
347                            test_task_id=tid.item(),
348                        )[0]
349                        for i, tid in enumerate(task_labels_replay)
350                    ]
351                )
352
353                student_logits_replay = torch.cat(
354                    [
355                        self.model.heads(
356                            student_feature_replay[i].unsqueeze(0), task_id=tid
357                        )
358                        for i, tid in enumerate(task_labels_replay)
359                    ]
360                )
361
362                with torch.no_grad():  # stop updating the previous heads
363                    teacher_logits_replay = logits_replay
364
365                loss = self.model.distillation_reg(
366                    student_logits=student_logits_replay,
367                    teacher_logits=teacher_logits_replay,
368                )
369                loss += self.model.replay_ce_factor * self.model.criterion(
370                    student_logits_replay, labels_replay.long()
371                )
372
373                self.model.manual_backward(loss)  # calculate the gradients
374
375                batch_task_ids = [int(task_id) for task_id in torch.unique(task_labels_replay)]
376                batch_affected_task_masks = []
377                for task_id in batch_task_ids:
378                    affected_task_mask = {
379                        layer_name: layer_mask.to(model_device)
380                        if layer_mask.device != model_device
381                        else layer_mask
382                        for layer_name, layer_mask in self.model.backbone.masks[
383                            task_id
384                        ].items()
385                    }
386                    batch_affected_task_masks.append(affected_task_mask)
387                union_batch_affected_mask = self.model.backbone.combine_masks(
388                    batch_affected_task_masks,
389                    mode="union",  
390                )
391                unlearning_and_affected_mask = self.model.backbone.combine_masks(
392                    [union_unlearning_mask, union_batch_affected_mask],
393                    mode="intersection",
394                )
395
396                self.model.clip_grad_by_mask(
397                    mask=unlearning_and_affected_mask, aggregation_mode="min"
398                )
399
400                # update parameters with the modified gradients
401                opt.step()

Repair the model with replay after unlearning.

def unlearn(self) -> None:
    def unlearn(self) -> None:
        r"""Unlearn the requested unlearning tasks (`self.unlearning_task_ids`) in the current task `self.task_id`."""

        # delete the corresponding parameter update records
        self.delete_update()

        for unlearning_task_id in self.unlearning_task_ids:

            # delete the data of the unlearning task from the memory buffer
            self.model.memory_buffer.delete_task(unlearning_task_id)

        # optionally restore intersecting parameters from the backup backbones
        if self.if_backup_compensation:
            self.compensate_by_backup()

        # optionally repair the affected tasks using replay data
        if self.if_replay_repairing:
            self.replay_repairing()

        # do not delete the masks and other related info of the unlearning tasks, as they may be needed in testing

Unlearn the requested unlearning tasks (self.unlearning_task_ids) in the current task self.task_id.