clarena.cul_algorithms.amnesiac_hat_unlearn
The submodule in cul_algorithms for the AmnesiacHAT unlearning algorithm.
1r""" 2The submoduule in `cul_algorithms` for AmnesiacHAT unlearning algorithm. 3""" 4 5__all__ = ["AmnesiacHATUnlearn"] 6 7import logging 8 9import torch 10from rich.progress import track 11 12from clarena.cl_algorithms.amnesiac_hat import AmnesiacHAT 13from clarena.cul_algorithms import AmnesiacCULAlgorithm 14 15# always get logger for built-in logging in each module 16pylogger = logging.getLogger(__name__) 17 18 19class AmnesiacHATUnlearn(AmnesiacCULAlgorithm): 20 r"""The base class of the AmnesiacHAT unlearning algorithm.""" 21 22 def __init__( 23 self, 24 model: AmnesiacHAT, 25 if_backup_compensation: bool, 26 compensate_order: str | None, 27 if_replay_repairing: bool, 28 repair_batch_size: int | None, 29 repair_num_steps: int | None, 30 repair_strategy: str | None, 31 ) -> None: 32 r"""Initialize the unlearning algorithm with the continual learning model. 33 34 **Args:** 35 - **model** (`AmnesiacHAT`): the continual learning model (`CLAlgorithm` object which already contains the backbone and heads). It must be an `AmnesiacHAT` algorithm. 36 - **if_backup_compensation** (`bool`): whether to perform compensation using the backup backbones after unlearning. 37 - **compensate_order** (`str` | `None`): the order to compensate the affected tasks after unlearning (used when `if_backup_compensation` is 'True'), must be: 38 - 'forward': from oldest to newest. 39 - 'reverse': from newest to oldest. 40 - **if_replay_repairing** (`bool`): whether to perform replay after unlearning. 41 - **repair_batch_size** (`int` | `None`): the batch size used during the replay repairing after unlearning (used when `if_replay_repairing` is 'True'). 42 - **repair_num_steps** (`int` | `None`): the number of steps to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). 43 - **repair_strategy** (`str` | `None`): the strategy to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). 
must be: 44 - 'joint': use joint replay data from all affected tasks for repairing. 45 - 'sequential_finetuning': use replay data from each affected task one by one (from oldest to newest) for repairing, with only one epoch per task. This forms a mini continual learning process during repairing, where we use Finetuning (no additional operation) to learn each affected task sequentially. 46 - 'sequential_adahat': use replay data from each affected task one by one (from oldest to newest) for repairing. This forms a mini continual learning process during repairing, where we use AdaHAT (no mask sparsity reg) to learn each affected task sequentially. 47 """ 48 super().__init__(model=model) 49 50 self.if_backup_compensation: bool = if_backup_compensation 51 r"""Whether to perform compensation using the backup backbones after unlearning.""" 52 if self.if_backup_compensation: 53 self.compensate_order: str = compensate_order 54 r"""The order to compensate the affected tasks after unlearning.""" 55 56 self.if_replay_repairing: bool = if_replay_repairing 57 r"""Whether to perform replay repairing after unlearning.""" 58 if self.if_replay_repairing: 59 self.repair_batch_size: int = repair_batch_size 60 r"""The batch size used during the replay repairing after unlearning.""" 61 self.repair_num_steps: int = repair_num_steps 62 r"""The number of steps to perform replay repairing after unlearning.""" 63 self.repair_strategy: str = repair_strategy 64 r"""The strategy to perform replay repairing after unlearning.""" 65 66 def compensate_by_backup(self) -> None: 67 r"""Compensate the model using the backup backbones after unlearning.""" 68 69 unlearning_task_id = self.unlearning_task_ids[ 70 0 71 ] # only one unlearning task is supported for now 72 73 task_ids_to_compensate = self.model.affected_tasks_after_unlearning() 74 75 if len(task_ids_to_compensate) == 0: 76 pylogger.info( 77 "No tasks to compensate after unlearning. Skipping compensation phase." 
78 ) 79 return 80 81 if self.compensate_order == "reverse": 82 task_ids_to_compensate.reverse() # compensate in reverse order 83 84 pylogger.debug( 85 "Affected tasks by unlearning task %s is %s, will be compensated in this order.", 86 unlearning_task_id, 87 task_ids_to_compensate, 88 ) 89 90 for task_id_to_compensate in task_ids_to_compensate: 91 92 # get the backup state dict 93 backup_state_dict = self.model.backbone.backup_state_dicts[ 94 (unlearning_task_id, task_id_to_compensate) 95 ] 96 97 # only compensate the intersected neurons between the unlearning task and the affected task 98 compensate_mask = self.model.backbone.combine_masks( 99 [ 100 self.model.backbone.masks[task_id_to_compensate], 101 self.model.backbone.masks[unlearning_task_id], 102 ], 103 mode="intersection", 104 ) 105 106 for layer_name in self.model.backbone.weighted_layer_names: 107 layer = self.model.backbone.get_layer_by_name(layer_name) 108 109 # construct parameter-wise mask for the layer 110 weight_mask, bias_mask = ( 111 self.model.backbone.get_layer_measure_parameter_wise( 112 neuron_wise_measure=compensate_mask, 113 layer_name=layer_name, 114 aggregation_mode="min", 115 ) 116 ) 117 118 # compensate the parameters using the backup state dict 119 target_device = layer.weight.device 120 target_dtype = layer.weight.dtype 121 if weight_mask.device != target_device: 122 weight_mask = weight_mask.to(device=target_device) 123 backup_weight = backup_state_dict[ 124 layer_name.replace("/", ".") + ".weight" 125 ].to(device=target_device, dtype=target_dtype) 126 layer.weight.data = torch.where( 127 weight_mask.bool(), 128 backup_weight, 129 layer.weight.data, 130 ) 131 if layer.bias is not None: 132 if bias_mask.device != target_device: 133 bias_mask = bias_mask.to(device=target_device) 134 backup_bias = backup_state_dict[ 135 layer_name.replace("/", ".") + ".bias" 136 ].to(device=target_device, dtype=layer.bias.dtype) 137 layer.bias.data = torch.where( 138 bias_mask.bool(), 139 backup_bias, 
140 layer.bias.data, 141 ) 142 143 pylogger.debug( 144 "Compensated affected task %s using backup from unlearning task %s.", 145 task_id_to_compensate, 146 unlearning_task_id, 147 ) 148 149 def replay_repairing(self) -> None: 150 r"""Repairing the model with replay after unlearning.""" 151 152 task_ids_to_repair = self.model.affected_tasks_after_unlearning() 153 154 if len(task_ids_to_repair) == 0: 155 pylogger.info( 156 "No tasks to repair after unlearning. Skipping repairing phase." 157 ) 158 return 159 else: 160 pylogger.info( 161 "Starting replay repairing tasks %s, after unlearning: %s. Repair strategy: %s.", 162 task_ids_to_repair, 163 self.unlearning_task_ids, 164 self.repair_strategy, 165 ) 166 167 # align model device with replay buffer if needed (trainer may move model to CPU after fit) 168 buffer_device = ( 169 self.model.memory_buffer.examples.device 170 if self.model.memory_buffer.examples.numel() > 0 171 else next(self.model.parameters()).device 172 ) 173 if next(self.model.parameters()).device != buffer_device: 174 self.model.to(buffer_device) 175 model_device = next(self.model.parameters()).device 176 177 def _move_optimizer_state_to_device(optimizer, device: torch.device) -> None: 178 if isinstance(optimizer, (list, tuple)): 179 for opt in optimizer: 180 _move_optimizer_state_to_device(opt, device) 181 return 182 opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer 183 for state in opt.state.values(): 184 for key, value in state.items(): 185 if torch.is_tensor(value) and value.device != device: 186 state[key] = value.to(device) 187 188 # build the unlearning mask that aggregates all unlearning tasks. 
This mask is used to clip the gradients during replay repairing to prevent changing unaffected parameters 189 union_unlearning_mask = self.model.backbone.combine_masks( 190 [ 191 self.model.backbone.masks[unlearning_task_id] 192 for unlearning_task_id in self.unlearning_task_ids 193 ], 194 mode="union", 195 ) 196 union_unlearning_mask = { 197 layer_name: mask_tensor.to(model_device) 198 if mask_tensor.device != model_device 199 else mask_tensor 200 for layer_name, mask_tensor in union_unlearning_mask.items() 201 } 202 203 if ( 204 self.repair_strategy == "sequential_finetuning" 205 or self.repair_strategy == "sequential_adahat" 206 ): 207 208 summative_mask_for_previous_tasks_in_replay_repairing = { 209 layer_name: torch.zeros( 210 self.model.backbone.get_layer_by_name(layer_name).weight.shape[0], 211 device=model_device, 212 ) 213 for layer_name in self.model.backbone.weighted_layer_names 214 } 215 216 opt = self.model.optimizers() 217 _move_optimizer_state_to_device(opt, model_device) 218 219 total_tasks = len(task_ids_to_repair) 220 for task_index, task_id_to_repair in enumerate(task_ids_to_repair, start=1): 221 affected_task_mask = { 222 layer_name: layer_mask.to(model_device) 223 if layer_mask.device != model_device 224 else layer_mask 225 for layer_name, layer_mask in self.model.backbone.masks[ 226 task_id_to_repair 227 ].items() 228 } 229 unlearning_and_affected_mask = self.model.backbone.combine_masks( 230 [union_unlearning_mask, affected_task_mask], 231 mode="intersection", 232 ) 233 234 for s in track( 235 range(self.repair_num_steps), 236 description=f"Replay repairing task {task_id_to_repair} ({task_index}/{total_tasks})", 237 transient=True, 238 ): 239 # get replay data for repairing from memory buffer 240 x_replay, labels_replay, logits_replay, _ = ( 241 self.model.memory_buffer.get_data( 242 self.repair_batch_size, 243 included_tasks=[task_id_to_repair], 244 ) 245 ) 246 if x_replay.device != model_device: 247 x_replay = x_replay.to(model_device) 248 
if labels_replay.device != model_device: 249 labels_replay = labels_replay.to(model_device) 250 if logits_replay.device != model_device: 251 logits_replay = logits_replay.to(model_device) 252 253 # zero the gradients before forward pass in manual optimization mode 254 opt.zero_grad() 255 256 student_feature_replay = self.model.backbone( 257 x_replay, 258 stage="test", 259 test_task_id=task_id_to_repair, 260 )[0] 261 262 student_logits_replay = self.model.heads( 263 student_feature_replay, task_id=task_id_to_repair 264 ) 265 266 with torch.no_grad(): # stop updating the previous heads 267 teacher_logits_replay = logits_replay 268 269 loss = self.model.distillation_reg( 270 student_logits=student_logits_replay, 271 teacher_logits=teacher_logits_replay, 272 ) 273 loss += self.model.replay_ce_factor * self.model.criterion( 274 student_logits_replay, labels_replay.long() 275 ) 276 277 self.model.manual_backward(loss) # calculate the gradients 278 279 # Clip gradients outside the intersection between 280 # unlearning-affected units and the current affected task. 
281 282 self.model.clip_grad_by_mask( 283 mask=unlearning_and_affected_mask, aggregation_mode="min" 284 ) 285 286 if self.repair_strategy == "sequential_adahat": 287 self.model.clip_grad_by_adjustment_in_replay_repairing( 288 summative_mask_for_previous_tasks_in_replay_repairing=summative_mask_for_previous_tasks_in_replay_repairing 289 ) 290 291 # update parameters with the modified gradients 292 opt.step() 293 294 summative_mask_for_previous_tasks_in_replay_repairing = { 295 layer_name: summative_mask_for_previous_tasks_in_replay_repairing[ 296 layer_name 297 ] 298 + ( 299 self.model.backbone.masks[task_id_to_repair][layer_name].to( 300 model_device 301 ) 302 if self.model.backbone.masks[task_id_to_repair][ 303 layer_name 304 ].device 305 != model_device 306 else self.model.backbone.masks[task_id_to_repair][layer_name] 307 ) 308 for layer_name in self.model.backbone.weighted_layer_names 309 } 310 311 elif self.repair_strategy == "joint": 312 313 opt = self.model.optimizers() 314 _move_optimizer_state_to_device(opt, model_device) 315 316 for s in track( 317 range(self.repair_num_steps), 318 description="Replay repairing (joint)", 319 transient=True, 320 ): 321 322 # get replay data for repairing from memory buffer 323 x_replay, labels_replay, logits_replay, task_labels_replay = ( 324 self.model.memory_buffer.get_data( 325 self.repair_batch_size, 326 included_tasks=task_ids_to_repair, 327 ) 328 ) 329 if x_replay.device != model_device: 330 x_replay = x_replay.to(model_device) 331 if labels_replay.device != model_device: 332 labels_replay = labels_replay.to(model_device) 333 if logits_replay.device != model_device: 334 logits_replay = logits_replay.to(model_device) 335 if task_labels_replay.device != model_device: 336 task_labels_replay = task_labels_replay.to(model_device) 337 338 # zero the gradients before forward pass in manual optimization mode 339 opt.zero_grad() 340 341 student_feature_replay = torch.cat( 342 [ 343 self.model.backbone( 344 
x_replay[i].unsqueeze(0), 345 stage="test", 346 test_task_id=tid.item(), 347 )[0] 348 for i, tid in enumerate(task_labels_replay) 349 ] 350 ) 351 352 student_logits_replay = torch.cat( 353 [ 354 self.model.heads( 355 student_feature_replay[i].unsqueeze(0), task_id=tid 356 ) 357 for i, tid in enumerate(task_labels_replay) 358 ] 359 ) 360 361 with torch.no_grad(): # stop updating the previous heads 362 teacher_logits_replay = logits_replay 363 364 loss = self.model.distillation_reg( 365 student_logits=student_logits_replay, 366 teacher_logits=teacher_logits_replay, 367 ) 368 loss += self.model.replay_ce_factor * self.model.criterion( 369 student_logits_replay, labels_replay.long() 370 ) 371 372 self.model.manual_backward(loss) # calculate the gradients 373 374 batch_task_ids = [int(task_id) for task_id in torch.unique(task_labels_replay)] 375 batch_affected_task_masks = [] 376 for task_id in batch_task_ids: 377 affected_task_mask = { 378 layer_name: layer_mask.to(model_device) 379 if layer_mask.device != model_device 380 else layer_mask 381 for layer_name, layer_mask in self.model.backbone.masks[ 382 task_id 383 ].items() 384 } 385 batch_affected_task_masks.append(affected_task_mask) 386 union_batch_affected_mask = self.model.backbone.combine_masks( 387 batch_affected_task_masks, 388 mode="union", 389 ) 390 unlearning_and_affected_mask = self.model.backbone.combine_masks( 391 [union_unlearning_mask, union_batch_affected_mask], 392 mode="intersection", 393 ) 394 395 self.model.clip_grad_by_mask( 396 mask=unlearning_and_affected_mask, aggregation_mode="min" 397 ) 398 399 # update parameters with the modified gradients 400 opt.step() 401 402 def unlearn(self) -> None: 403 r"""Unlearn the requested unlearning tasks (`self.unlearning_task_ids`) in the current task `self.task_id`.""" 404 405 # delete the corresponding parameter update records 406 self.delete_update() 407 408 for unlearning_task_id in self.unlearning_task_ids: 409 410 # delete the data of the unlearning 
task from the memory buffer 411 self.model.memory_buffer.delete_task(unlearning_task_id) 412 413 if self.if_backup_compensation: 414 self.compensate_by_backup() 415 416 if self.if_replay_repairing: 417 self.replay_repairing() 418 419 # do not delete the masks and other related info of the unlearning tasks, as they may be needed in testing
class
AmnesiacHATUnlearn(clarena.cul_algorithms.base.AmnesiacCULAlgorithm):
20class AmnesiacHATUnlearn(AmnesiacCULAlgorithm): 21 r"""The base class of the AmnesiacHAT unlearning algorithm.""" 22 23 def __init__( 24 self, 25 model: AmnesiacHAT, 26 if_backup_compensation: bool, 27 compensate_order: str | None, 28 if_replay_repairing: bool, 29 repair_batch_size: int | None, 30 repair_num_steps: int | None, 31 repair_strategy: str | None, 32 ) -> None: 33 r"""Initialize the unlearning algorithm with the continual learning model. 34 35 **Args:** 36 - **model** (`AmnesiacHAT`): the continual learning model (`CLAlgorithm` object which already contains the backbone and heads). It must be an `AmnesiacHAT` algorithm. 37 - **if_backup_compensation** (`bool`): whether to perform compensation using the backup backbones after unlearning. 38 - **compensate_order** (`str` | `None`): the order to compensate the affected tasks after unlearning (used when `if_backup_compensation` is 'True'), must be: 39 - 'forward': from oldest to newest. 40 - 'reverse': from newest to oldest. 41 - **if_replay_repairing** (`bool`): whether to perform replay after unlearning. 42 - **repair_batch_size** (`int` | `None`): the batch size used during the replay repairing after unlearning (used when `if_replay_repairing` is 'True'). 43 - **repair_num_steps** (`int` | `None`): the number of steps to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). 44 - **repair_strategy** (`str` | `None`): the strategy to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). must be: 45 - 'joint': use joint replay data from all affected tasks for repairing. 46 - 'sequential_finetuning': use replay data from each affected task one by one (from oldest to newest) for repairing, with only one epoch per task. This forms a mini continual learning process during repairing, where we use Finetuning (no additional operation) to learn each affected task sequentially. 
47 - 'sequential_adahat': use replay data from each affected task one by one (from oldest to newest) for repairing. This forms a mini continual learning process during repairing, where we use AdaHAT (no mask sparsity reg) to learn each affected task sequentially. 48 """ 49 super().__init__(model=model) 50 51 self.if_backup_compensation: bool = if_backup_compensation 52 r"""Whether to perform compensation using the backup backbones after unlearning.""" 53 if self.if_backup_compensation: 54 self.compensate_order: str = compensate_order 55 r"""The order to compensate the affected tasks after unlearning.""" 56 57 self.if_replay_repairing: bool = if_replay_repairing 58 r"""Whether to perform replay repairing after unlearning.""" 59 if self.if_replay_repairing: 60 self.repair_batch_size: int = repair_batch_size 61 r"""The batch size used during the replay repairing after unlearning.""" 62 self.repair_num_steps: int = repair_num_steps 63 r"""The number of steps to perform replay repairing after unlearning.""" 64 self.repair_strategy: str = repair_strategy 65 r"""The strategy to perform replay repairing after unlearning.""" 66 67 def compensate_by_backup(self) -> None: 68 r"""Compensate the model using the backup backbones after unlearning.""" 69 70 unlearning_task_id = self.unlearning_task_ids[ 71 0 72 ] # only one unlearning task is supported for now 73 74 task_ids_to_compensate = self.model.affected_tasks_after_unlearning() 75 76 if len(task_ids_to_compensate) == 0: 77 pylogger.info( 78 "No tasks to compensate after unlearning. Skipping compensation phase." 
79 ) 80 return 81 82 if self.compensate_order == "reverse": 83 task_ids_to_compensate.reverse() # compensate in reverse order 84 85 pylogger.debug( 86 "Affected tasks by unlearning task %s is %s, will be compensated in this order.", 87 unlearning_task_id, 88 task_ids_to_compensate, 89 ) 90 91 for task_id_to_compensate in task_ids_to_compensate: 92 93 # get the backup state dict 94 backup_state_dict = self.model.backbone.backup_state_dicts[ 95 (unlearning_task_id, task_id_to_compensate) 96 ] 97 98 # only compensate the intersected neurons between the unlearning task and the affected task 99 compensate_mask = self.model.backbone.combine_masks( 100 [ 101 self.model.backbone.masks[task_id_to_compensate], 102 self.model.backbone.masks[unlearning_task_id], 103 ], 104 mode="intersection", 105 ) 106 107 for layer_name in self.model.backbone.weighted_layer_names: 108 layer = self.model.backbone.get_layer_by_name(layer_name) 109 110 # construct parameter-wise mask for the layer 111 weight_mask, bias_mask = ( 112 self.model.backbone.get_layer_measure_parameter_wise( 113 neuron_wise_measure=compensate_mask, 114 layer_name=layer_name, 115 aggregation_mode="min", 116 ) 117 ) 118 119 # compensate the parameters using the backup state dict 120 target_device = layer.weight.device 121 target_dtype = layer.weight.dtype 122 if weight_mask.device != target_device: 123 weight_mask = weight_mask.to(device=target_device) 124 backup_weight = backup_state_dict[ 125 layer_name.replace("/", ".") + ".weight" 126 ].to(device=target_device, dtype=target_dtype) 127 layer.weight.data = torch.where( 128 weight_mask.bool(), 129 backup_weight, 130 layer.weight.data, 131 ) 132 if layer.bias is not None: 133 if bias_mask.device != target_device: 134 bias_mask = bias_mask.to(device=target_device) 135 backup_bias = backup_state_dict[ 136 layer_name.replace("/", ".") + ".bias" 137 ].to(device=target_device, dtype=layer.bias.dtype) 138 layer.bias.data = torch.where( 139 bias_mask.bool(), 140 backup_bias, 
141 layer.bias.data, 142 ) 143 144 pylogger.debug( 145 "Compensated affected task %s using backup from unlearning task %s.", 146 task_id_to_compensate, 147 unlearning_task_id, 148 ) 149 150 def replay_repairing(self) -> None: 151 r"""Repairing the model with replay after unlearning.""" 152 153 task_ids_to_repair = self.model.affected_tasks_after_unlearning() 154 155 if len(task_ids_to_repair) == 0: 156 pylogger.info( 157 "No tasks to repair after unlearning. Skipping repairing phase." 158 ) 159 return 160 else: 161 pylogger.info( 162 "Starting replay repairing tasks %s, after unlearning: %s. Repair strategy: %s.", 163 task_ids_to_repair, 164 self.unlearning_task_ids, 165 self.repair_strategy, 166 ) 167 168 # align model device with replay buffer if needed (trainer may move model to CPU after fit) 169 buffer_device = ( 170 self.model.memory_buffer.examples.device 171 if self.model.memory_buffer.examples.numel() > 0 172 else next(self.model.parameters()).device 173 ) 174 if next(self.model.parameters()).device != buffer_device: 175 self.model.to(buffer_device) 176 model_device = next(self.model.parameters()).device 177 178 def _move_optimizer_state_to_device(optimizer, device: torch.device) -> None: 179 if isinstance(optimizer, (list, tuple)): 180 for opt in optimizer: 181 _move_optimizer_state_to_device(opt, device) 182 return 183 opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer 184 for state in opt.state.values(): 185 for key, value in state.items(): 186 if torch.is_tensor(value) and value.device != device: 187 state[key] = value.to(device) 188 189 # build the unlearning mask that aggregates all unlearning tasks. 
This mask is used to clip the gradients during replay repairing to prevent changing unaffected parameters 190 union_unlearning_mask = self.model.backbone.combine_masks( 191 [ 192 self.model.backbone.masks[unlearning_task_id] 193 for unlearning_task_id in self.unlearning_task_ids 194 ], 195 mode="union", 196 ) 197 union_unlearning_mask = { 198 layer_name: mask_tensor.to(model_device) 199 if mask_tensor.device != model_device 200 else mask_tensor 201 for layer_name, mask_tensor in union_unlearning_mask.items() 202 } 203 204 if ( 205 self.repair_strategy == "sequential_finetuning" 206 or self.repair_strategy == "sequential_adahat" 207 ): 208 209 summative_mask_for_previous_tasks_in_replay_repairing = { 210 layer_name: torch.zeros( 211 self.model.backbone.get_layer_by_name(layer_name).weight.shape[0], 212 device=model_device, 213 ) 214 for layer_name in self.model.backbone.weighted_layer_names 215 } 216 217 opt = self.model.optimizers() 218 _move_optimizer_state_to_device(opt, model_device) 219 220 total_tasks = len(task_ids_to_repair) 221 for task_index, task_id_to_repair in enumerate(task_ids_to_repair, start=1): 222 affected_task_mask = { 223 layer_name: layer_mask.to(model_device) 224 if layer_mask.device != model_device 225 else layer_mask 226 for layer_name, layer_mask in self.model.backbone.masks[ 227 task_id_to_repair 228 ].items() 229 } 230 unlearning_and_affected_mask = self.model.backbone.combine_masks( 231 [union_unlearning_mask, affected_task_mask], 232 mode="intersection", 233 ) 234 235 for s in track( 236 range(self.repair_num_steps), 237 description=f"Replay repairing task {task_id_to_repair} ({task_index}/{total_tasks})", 238 transient=True, 239 ): 240 # get replay data for repairing from memory buffer 241 x_replay, labels_replay, logits_replay, _ = ( 242 self.model.memory_buffer.get_data( 243 self.repair_batch_size, 244 included_tasks=[task_id_to_repair], 245 ) 246 ) 247 if x_replay.device != model_device: 248 x_replay = x_replay.to(model_device) 249 
if labels_replay.device != model_device: 250 labels_replay = labels_replay.to(model_device) 251 if logits_replay.device != model_device: 252 logits_replay = logits_replay.to(model_device) 253 254 # zero the gradients before forward pass in manual optimization mode 255 opt.zero_grad() 256 257 student_feature_replay = self.model.backbone( 258 x_replay, 259 stage="test", 260 test_task_id=task_id_to_repair, 261 )[0] 262 263 student_logits_replay = self.model.heads( 264 student_feature_replay, task_id=task_id_to_repair 265 ) 266 267 with torch.no_grad(): # stop updating the previous heads 268 teacher_logits_replay = logits_replay 269 270 loss = self.model.distillation_reg( 271 student_logits=student_logits_replay, 272 teacher_logits=teacher_logits_replay, 273 ) 274 loss += self.model.replay_ce_factor * self.model.criterion( 275 student_logits_replay, labels_replay.long() 276 ) 277 278 self.model.manual_backward(loss) # calculate the gradients 279 280 # Clip gradients outside the intersection between 281 # unlearning-affected units and the current affected task. 
282 283 self.model.clip_grad_by_mask( 284 mask=unlearning_and_affected_mask, aggregation_mode="min" 285 ) 286 287 if self.repair_strategy == "sequential_adahat": 288 self.model.clip_grad_by_adjustment_in_replay_repairing( 289 summative_mask_for_previous_tasks_in_replay_repairing=summative_mask_for_previous_tasks_in_replay_repairing 290 ) 291 292 # update parameters with the modified gradients 293 opt.step() 294 295 summative_mask_for_previous_tasks_in_replay_repairing = { 296 layer_name: summative_mask_for_previous_tasks_in_replay_repairing[ 297 layer_name 298 ] 299 + ( 300 self.model.backbone.masks[task_id_to_repair][layer_name].to( 301 model_device 302 ) 303 if self.model.backbone.masks[task_id_to_repair][ 304 layer_name 305 ].device 306 != model_device 307 else self.model.backbone.masks[task_id_to_repair][layer_name] 308 ) 309 for layer_name in self.model.backbone.weighted_layer_names 310 } 311 312 elif self.repair_strategy == "joint": 313 314 opt = self.model.optimizers() 315 _move_optimizer_state_to_device(opt, model_device) 316 317 for s in track( 318 range(self.repair_num_steps), 319 description="Replay repairing (joint)", 320 transient=True, 321 ): 322 323 # get replay data for repairing from memory buffer 324 x_replay, labels_replay, logits_replay, task_labels_replay = ( 325 self.model.memory_buffer.get_data( 326 self.repair_batch_size, 327 included_tasks=task_ids_to_repair, 328 ) 329 ) 330 if x_replay.device != model_device: 331 x_replay = x_replay.to(model_device) 332 if labels_replay.device != model_device: 333 labels_replay = labels_replay.to(model_device) 334 if logits_replay.device != model_device: 335 logits_replay = logits_replay.to(model_device) 336 if task_labels_replay.device != model_device: 337 task_labels_replay = task_labels_replay.to(model_device) 338 339 # zero the gradients before forward pass in manual optimization mode 340 opt.zero_grad() 341 342 student_feature_replay = torch.cat( 343 [ 344 self.model.backbone( 345 
x_replay[i].unsqueeze(0), 346 stage="test", 347 test_task_id=tid.item(), 348 )[0] 349 for i, tid in enumerate(task_labels_replay) 350 ] 351 ) 352 353 student_logits_replay = torch.cat( 354 [ 355 self.model.heads( 356 student_feature_replay[i].unsqueeze(0), task_id=tid 357 ) 358 for i, tid in enumerate(task_labels_replay) 359 ] 360 ) 361 362 with torch.no_grad(): # stop updating the previous heads 363 teacher_logits_replay = logits_replay 364 365 loss = self.model.distillation_reg( 366 student_logits=student_logits_replay, 367 teacher_logits=teacher_logits_replay, 368 ) 369 loss += self.model.replay_ce_factor * self.model.criterion( 370 student_logits_replay, labels_replay.long() 371 ) 372 373 self.model.manual_backward(loss) # calculate the gradients 374 375 batch_task_ids = [int(task_id) for task_id in torch.unique(task_labels_replay)] 376 batch_affected_task_masks = [] 377 for task_id in batch_task_ids: 378 affected_task_mask = { 379 layer_name: layer_mask.to(model_device) 380 if layer_mask.device != model_device 381 else layer_mask 382 for layer_name, layer_mask in self.model.backbone.masks[ 383 task_id 384 ].items() 385 } 386 batch_affected_task_masks.append(affected_task_mask) 387 union_batch_affected_mask = self.model.backbone.combine_masks( 388 batch_affected_task_masks, 389 mode="union", 390 ) 391 unlearning_and_affected_mask = self.model.backbone.combine_masks( 392 [union_unlearning_mask, union_batch_affected_mask], 393 mode="intersection", 394 ) 395 396 self.model.clip_grad_by_mask( 397 mask=unlearning_and_affected_mask, aggregation_mode="min" 398 ) 399 400 # update parameters with the modified gradients 401 opt.step() 402 403 def unlearn(self) -> None: 404 r"""Unlearn the requested unlearning tasks (`self.unlearning_task_ids`) in the current task `self.task_id`.""" 405 406 # delete the corresponding parameter update records 407 self.delete_update() 408 409 for unlearning_task_id in self.unlearning_task_ids: 410 411 # delete the data of the unlearning 
task from the memory buffer 412 self.model.memory_buffer.delete_task(unlearning_task_id) 413 414 if self.if_backup_compensation: 415 self.compensate_by_backup() 416 417 if self.if_replay_repairing: 418 self.replay_repairing() 419 420 # do not delete the masks and other related info of the unlearning tasks, as they may be needed in testing
The base class of the AmnesiacHAT unlearning algorithm.
AmnesiacHATUnlearn( model: clarena.cl_algorithms.amnesiac_hat.AmnesiacHAT, if_backup_compensation: bool, compensate_order: str | None, if_replay_repairing: bool, repair_batch_size: int | None, repair_num_steps: int | None, repair_strategy: str | None)
23 def __init__( 24 self, 25 model: AmnesiacHAT, 26 if_backup_compensation: bool, 27 compensate_order: str | None, 28 if_replay_repairing: bool, 29 repair_batch_size: int | None, 30 repair_num_steps: int | None, 31 repair_strategy: str | None, 32 ) -> None: 33 r"""Initialize the unlearning algorithm with the continual learning model. 34 35 **Args:** 36 - **model** (`AmnesiacHAT`): the continual learning model (`CLAlgorithm` object which already contains the backbone and heads). It must be an `AmnesiacHAT` algorithm. 37 - **if_backup_compensation** (`bool`): whether to perform compensation using the backup backbones after unlearning. 38 - **compensate_order** (`str` | `None`): the order to compensate the affected tasks after unlearning (used when `if_backup_compensation` is 'True'), must be: 39 - 'forward': from oldest to newest. 40 - 'reverse': from newest to oldest. 41 - **if_replay_repairing** (`bool`): whether to perform replay after unlearning. 42 - **repair_batch_size** (`int` | `None`): the batch size used during the replay repairing after unlearning (used when `if_replay_repairing` is 'True'). 43 - **repair_num_steps** (`int` | `None`): the number of steps to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). 44 - **repair_strategy** (`str` | `None`): the strategy to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'). must be: 45 - 'joint': use joint replay data from all affected tasks for repairing. 46 - 'sequential_finetuning': use replay data from each affected task one by one (from oldest to newest) for repairing, with only one epoch per task. This forms a mini continual learning process during repairing, where we use Finetuning (no additional operation) to learn each affected task sequentially. 47 - 'sequential_adahat': use replay data from each affected task one by one (from oldest to newest) for repairing. 
This forms a mini continual learning process during repairing, where we use AdaHAT (no mask sparsity reg) to learn each affected task sequentially. 48 """ 49 super().__init__(model=model) 50 51 self.if_backup_compensation: bool = if_backup_compensation 52 r"""Whether to perform compensation using the backup backbones after unlearning.""" 53 if self.if_backup_compensation: 54 self.compensate_order: str = compensate_order 55 r"""The order to compensate the affected tasks after unlearning.""" 56 57 self.if_replay_repairing: bool = if_replay_repairing 58 r"""Whether to perform replay repairing after unlearning.""" 59 if self.if_replay_repairing: 60 self.repair_batch_size: int = repair_batch_size 61 r"""The batch size used during the replay repairing after unlearning.""" 62 self.repair_num_steps: int = repair_num_steps 63 r"""The number of steps to perform replay repairing after unlearning.""" 64 self.repair_strategy: str = repair_strategy 65 r"""The strategy to perform replay repairing after unlearning."""
Initialize the unlearning algorithm with the continual learning model.
Args:
- model (`AmnesiacHAT`): the continual learning model (`CLAlgorithm` object which already contains the backbone and heads). It must be an `AmnesiacHAT` algorithm.
- if_backup_compensation (`bool`): whether to perform compensation using the backup backbones after unlearning.
- compensate_order (`str` | `None`): the order to compensate the affected tasks after unlearning (used when `if_backup_compensation` is 'True'), must be:
    - 'forward': from oldest to newest.
    - 'reverse': from newest to oldest.
- if_replay_repairing (`bool`): whether to perform replay after unlearning.
- repair_batch_size (`int` | `None`): the batch size used during the replay repairing after unlearning (used when `if_replay_repairing` is 'True').
- repair_num_steps (`int` | `None`): the number of steps to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True').
- repair_strategy (`str` | `None`): the strategy to perform replay repairing after unlearning (used when `if_replay_repairing` is 'True'), must be:
    - 'joint': use joint replay data from all affected tasks for repairing.
    - 'sequential_finetuning': use replay data from each affected task one by one (from oldest to newest) for repairing, with only one epoch per task. This forms a mini continual learning process during repairing, where we use Finetuning (no additional operation) to learn each affected task sequentially.
    - 'sequential_adahat': use replay data from each affected task one by one (from oldest to newest) for repairing. This forms a mini continual learning process during repairing, where we use AdaHAT (no mask sparsity reg) to learn each affected task sequentially.
if_backup_compensation: bool
Whether to perform compensation using the backup backbones after unlearning.
def
compensate_by_backup(self) -> None:
67 def compensate_by_backup(self) -> None: 68 r"""Compensate the model using the backup backbones after unlearning.""" 69 70 unlearning_task_id = self.unlearning_task_ids[ 71 0 72 ] # only one unlearning task is supported for now 73 74 task_ids_to_compensate = self.model.affected_tasks_after_unlearning() 75 76 if len(task_ids_to_compensate) == 0: 77 pylogger.info( 78 "No tasks to compensate after unlearning. Skipping compensation phase." 79 ) 80 return 81 82 if self.compensate_order == "reverse": 83 task_ids_to_compensate.reverse() # compensate in reverse order 84 85 pylogger.debug( 86 "Affected tasks by unlearning task %s is %s, will be compensated in this order.", 87 unlearning_task_id, 88 task_ids_to_compensate, 89 ) 90 91 for task_id_to_compensate in task_ids_to_compensate: 92 93 # get the backup state dict 94 backup_state_dict = self.model.backbone.backup_state_dicts[ 95 (unlearning_task_id, task_id_to_compensate) 96 ] 97 98 # only compensate the intersected neurons between the unlearning task and the affected task 99 compensate_mask = self.model.backbone.combine_masks( 100 [ 101 self.model.backbone.masks[task_id_to_compensate], 102 self.model.backbone.masks[unlearning_task_id], 103 ], 104 mode="intersection", 105 ) 106 107 for layer_name in self.model.backbone.weighted_layer_names: 108 layer = self.model.backbone.get_layer_by_name(layer_name) 109 110 # construct parameter-wise mask for the layer 111 weight_mask, bias_mask = ( 112 self.model.backbone.get_layer_measure_parameter_wise( 113 neuron_wise_measure=compensate_mask, 114 layer_name=layer_name, 115 aggregation_mode="min", 116 ) 117 ) 118 119 # compensate the parameters using the backup state dict 120 target_device = layer.weight.device 121 target_dtype = layer.weight.dtype 122 if weight_mask.device != target_device: 123 weight_mask = weight_mask.to(device=target_device) 124 backup_weight = backup_state_dict[ 125 layer_name.replace("/", ".") + ".weight" 126 ].to(device=target_device, dtype=target_dtype) 
127 layer.weight.data = torch.where( 128 weight_mask.bool(), 129 backup_weight, 130 layer.weight.data, 131 ) 132 if layer.bias is not None: 133 if bias_mask.device != target_device: 134 bias_mask = bias_mask.to(device=target_device) 135 backup_bias = backup_state_dict[ 136 layer_name.replace("/", ".") + ".bias" 137 ].to(device=target_device, dtype=layer.bias.dtype) 138 layer.bias.data = torch.where( 139 bias_mask.bool(), 140 backup_bias, 141 layer.bias.data, 142 ) 143 144 pylogger.debug( 145 "Compensated affected task %s using backup from unlearning task %s.", 146 task_id_to_compensate, 147 unlearning_task_id, 148 )
Compensate the model using the backup backbones after unlearning.
def
replay_repairing(self) -> None:
    def replay_repairing(self) -> None:
        r"""Repairing the model with replay after unlearning.

        Replays buffered data of the tasks affected by unlearning and fine-tunes
        the model in manual-optimization mode. Gradients are clipped so that only
        parameters at the intersection of the unlearning masks and the affected
        tasks' masks can change. The behavior depends on `self.repair_strategy`:
        'sequential_finetuning' / 'sequential_adahat' repair affected tasks one
        by one, while 'joint' samples replay data across all affected tasks.
        """

        task_ids_to_repair = self.model.affected_tasks_after_unlearning()

        if len(task_ids_to_repair) == 0:
            pylogger.info(
                "No tasks to repair after unlearning. Skipping repairing phase."
            )
            return
        else:
            pylogger.info(
                "Starting replay repairing tasks %s, after unlearning: %s. Repair strategy: %s.",
                task_ids_to_repair,
                self.unlearning_task_ids,
                self.repair_strategy,
            )

        # align model device with replay buffer if needed (trainer may move model to CPU after fit)
        buffer_device = (
            self.model.memory_buffer.examples.device
            if self.model.memory_buffer.examples.numel() > 0
            else next(self.model.parameters()).device
        )
        if next(self.model.parameters()).device != buffer_device:
            self.model.to(buffer_device)
        model_device = next(self.model.parameters()).device

        def _move_optimizer_state_to_device(optimizer, device: torch.device) -> None:
            # `self.model.optimizers()` may return one optimizer, a wrapped
            # optimizer (exposing `.optimizer`), or a list/tuple of them —
            # presumably a Lightning convention; recurse/unwrap accordingly and
            # move every tensor in the optimizer state onto `device`.
            if isinstance(optimizer, (list, tuple)):
                for opt in optimizer:
                    _move_optimizer_state_to_device(opt, device)
                return
            opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
            for state in opt.state.values():
                for key, value in state.items():
                    if torch.is_tensor(value) and value.device != device:
                        state[key] = value.to(device)

        # build the unlearning mask that aggregates all unlearning tasks. This
        # mask is used to clip the gradients during replay repairing to prevent
        # changing unaffected parameters
        union_unlearning_mask = self.model.backbone.combine_masks(
            [
                self.model.backbone.masks[unlearning_task_id]
                for unlearning_task_id in self.unlearning_task_ids
            ],
            mode="union",
        )
        # move every layer's mask tensor to the model device once, up front
        union_unlearning_mask = {
            layer_name: mask_tensor.to(model_device)
            if mask_tensor.device != model_device
            else mask_tensor
            for layer_name, mask_tensor in union_unlearning_mask.items()
        }

        if (
            self.repair_strategy == "sequential_finetuning"
            or self.repair_strategy == "sequential_adahat"
        ):

            # running union (as a sum) of masks of tasks already repaired in
            # this mini continual learning process; consumed by the
            # AdaHAT-style gradient adjustment below
            summative_mask_for_previous_tasks_in_replay_repairing = {
                layer_name: torch.zeros(
                    self.model.backbone.get_layer_by_name(layer_name).weight.shape[0],
                    device=model_device,
                )
                for layer_name in self.model.backbone.weighted_layer_names
            }

            opt = self.model.optimizers()
            _move_optimizer_state_to_device(opt, model_device)

            total_tasks = len(task_ids_to_repair)
            for task_index, task_id_to_repair in enumerate(task_ids_to_repair, start=1):
                # this task's mask, on the model device
                affected_task_mask = {
                    layer_name: layer_mask.to(model_device)
                    if layer_mask.device != model_device
                    else layer_mask
                    for layer_name, layer_mask in self.model.backbone.masks[
                        task_id_to_repair
                    ].items()
                }
                # only units used by both the unlearning task(s) and this task
                # are allowed to receive gradient updates
                unlearning_and_affected_mask = self.model.backbone.combine_masks(
                    [union_unlearning_mask, affected_task_mask],
                    mode="intersection",
                )

                for s in track(
                    range(self.repair_num_steps),
                    description=f"Replay repairing task {task_id_to_repair} ({task_index}/{total_tasks})",
                    transient=True,
                ):
                    # get replay data for repairing from memory buffer
                    x_replay, labels_replay, logits_replay, _ = (
                        self.model.memory_buffer.get_data(
                            self.repair_batch_size,
                            included_tasks=[task_id_to_repair],
                        )
                    )
                    if x_replay.device != model_device:
                        x_replay = x_replay.to(model_device)
                    if labels_replay.device != model_device:
                        labels_replay = labels_replay.to(model_device)
                    if logits_replay.device != model_device:
                        logits_replay = logits_replay.to(model_device)

                    # zero the gradients before forward pass in manual optimization mode
                    opt.zero_grad()

                    student_feature_replay = self.model.backbone(
                        x_replay,
                        stage="test",
                        test_task_id=task_id_to_repair,
                    )[0]

                    student_logits_replay = self.model.heads(
                        student_feature_replay, task_id=task_id_to_repair
                    )

                    with torch.no_grad():  # stop updating the previous heads
                        # teacher targets are the logits stored in the buffer
                        teacher_logits_replay = logits_replay

                    # distillation loss against stored logits, plus a weighted
                    # cross-entropy term on the true labels
                    loss = self.model.distillation_reg(
                        student_logits=student_logits_replay,
                        teacher_logits=teacher_logits_replay,
                    )
                    loss += self.model.replay_ce_factor * self.model.criterion(
                        student_logits_replay, labels_replay.long()
                    )

                    self.model.manual_backward(loss)  # calculate the gradients

                    # Clip gradients outside the intersection between
                    # unlearning-affected units and the current affected task.
                    self.model.clip_grad_by_mask(
                        mask=unlearning_and_affected_mask, aggregation_mode="min"
                    )

                    if self.repair_strategy == "sequential_adahat":
                        # additionally scale gradients by the AdaHAT adjustment
                        # w.r.t. tasks repaired earlier in this loop
                        self.model.clip_grad_by_adjustment_in_replay_repairing(
                            summative_mask_for_previous_tasks_in_replay_repairing=summative_mask_for_previous_tasks_in_replay_repairing
                        )

                    # update parameters with the modified gradients
                    opt.step()

                # accumulate this task's mask (moved to the model device if
                # necessary) into the summative mask for subsequent tasks
                summative_mask_for_previous_tasks_in_replay_repairing = {
                    layer_name: summative_mask_for_previous_tasks_in_replay_repairing[
                        layer_name
                    ]
                    + (
                        self.model.backbone.masks[task_id_to_repair][layer_name].to(
                            model_device
                        )
                        if self.model.backbone.masks[task_id_to_repair][
                            layer_name
                        ].device
                        != model_device
                        else self.model.backbone.masks[task_id_to_repair][layer_name]
                    )
                    for layer_name in self.model.backbone.weighted_layer_names
                }

        elif self.repair_strategy == "joint":

            opt = self.model.optimizers()
            _move_optimizer_state_to_device(opt, model_device)

            for s in track(
                range(self.repair_num_steps),
                description="Replay repairing (joint)",
                transient=True,
            ):

                # get replay data for repairing from memory buffer
                x_replay, labels_replay, logits_replay, task_labels_replay = (
                    self.model.memory_buffer.get_data(
                        self.repair_batch_size,
                        included_tasks=task_ids_to_repair,
                    )
                )
                if x_replay.device != model_device:
                    x_replay = x_replay.to(model_device)
                if labels_replay.device != model_device:
                    labels_replay = labels_replay.to(model_device)
                if logits_replay.device != model_device:
                    logits_replay = logits_replay.to(model_device)
                if task_labels_replay.device != model_device:
                    task_labels_replay = task_labels_replay.to(model_device)

                # zero the gradients before forward pass in manual optimization mode
                opt.zero_grad()

                # forward each sample through its own task's masked backbone,
                # one at a time — task masks differ per sample within the batch,
                # so a single batched forward would be incorrect here
                student_feature_replay = torch.cat(
                    [
                        self.model.backbone(
                            x_replay[i].unsqueeze(0),
                            stage="test",
                            test_task_id=tid.item(),
                        )[0]
                        for i, tid in enumerate(task_labels_replay)
                    ]
                )

                # NOTE(review): heads are given `tid` (a 0-dim tensor) here while
                # the backbone above gets `tid.item()` — presumably `heads`
                # accepts tensor task ids; confirm against the heads API
                student_logits_replay = torch.cat(
                    [
                        self.model.heads(
                            student_feature_replay[i].unsqueeze(0), task_id=tid
                        )
                        for i, tid in enumerate(task_labels_replay)
                    ]
                )

                with torch.no_grad():  # stop updating the previous heads
                    teacher_logits_replay = logits_replay

                loss = self.model.distillation_reg(
                    student_logits=student_logits_replay,
                    teacher_logits=teacher_logits_replay,
                )
                loss += self.model.replay_ce_factor * self.model.criterion(
                    student_logits_replay, labels_replay.long()
                )

                self.model.manual_backward(loss)  # calculate the gradients

                # union of masks of the tasks actually present in this batch
                batch_task_ids = [int(task_id) for task_id in torch.unique(task_labels_replay)]
                batch_affected_task_masks = []
                for task_id in batch_task_ids:
                    affected_task_mask = {
                        layer_name: layer_mask.to(model_device)
                        if layer_mask.device != model_device
                        else layer_mask
                        for layer_name, layer_mask in self.model.backbone.masks[
                            task_id
                        ].items()
                    }
                    batch_affected_task_masks.append(affected_task_mask)
                union_batch_affected_mask = self.model.backbone.combine_masks(
                    batch_affected_task_masks,
                    mode="union",
                )
                unlearning_and_affected_mask = self.model.backbone.combine_masks(
                    [union_unlearning_mask, union_batch_affected_mask],
                    mode="intersection",
                )

                # only units in the intersection of unlearning and batch-affected
                # masks may be updated
                self.model.clip_grad_by_mask(
                    mask=unlearning_and_affected_mask, aggregation_mode="min"
                )

                # update parameters with the modified gradients
                opt.step()
Repairing the model with replay after unlearning.
def
unlearn(self) -> None:
403 def unlearn(self) -> None: 404 r"""Unlearn the requested unlearning tasks (`self.unlearning_task_ids`) in the current task `self.task_id`.""" 405 406 # delete the corresponding parameter update records 407 self.delete_update() 408 409 for unlearning_task_id in self.unlearning_task_ids: 410 411 # delete the data of the unlearning task from the memory buffer 412 self.model.memory_buffer.delete_task(unlearning_task_id) 413 414 if self.if_backup_compensation: 415 self.compensate_by_backup() 416 417 if self.if_replay_repairing: 418 self.replay_repairing() 419 420 # do not delete the masks and other related info of the unlearning tasks, as they may be needed in testing
Unlearn the requested unlearning tasks (self.unlearning_task_ids) in the current task self.task_id.