clarena.utils.eval
The submodule in utils for evaluation utilities.
1r"""The submodule in `utils` for evaluation utilities.""" 2 3import logging 4 5from lightning import LightningModule 6from torch import Tensor, nn 7from torch.utils.data import DataLoader 8 9from clarena.cl_algorithms import CLAlgorithm 10from clarena.cul_algorithms import CULAlgorithm 11 12# always get logger for built-in logging in each module 13pylogger = logging.getLogger(__name__) 14 15 16class CULEvaluation(LightningModule): 17 r"""Full evaluation module for continual unlearning.""" 18 19 def __init__( 20 self, 21 main_model: CULAlgorithm, 22 refretrain_model: CLAlgorithm, 23 reforiginal_model: CLAlgorithm | None, 24 dd_eval_task_ids: list[int], 25 ag_eval_task_ids: list[int], 26 ): 27 r""" 28 **Args:** 29 - **main_model** (`CULAlgorithm`): the main model to evaluate. 30 - **refretrain_model** (`CLAlgorithm`): the reference retrain model to evaluate against. 31 - **reforiginal_model** (`CLAlgorithm` | `None`): the reference original model that has been trained on all tasks. If `None`, AG-related evaluation is skipped. 32 - **dd_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the DD on. 33 - **ag_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the accuracy gain on. 34 """ 35 super().__init__() 36 37 self.criterion = nn.CrossEntropyLoss() 38 r"""The loss function bewteen the output logits and the target labels. Default is cross-entropy loss.""" 39 40 self.main_model = main_model 41 r"""The main model for evaluation.""" 42 self.refretrain_model = refretrain_model 43 r"""The reference retrain model for evaluation.""" 44 self.reforiginal_model = reforiginal_model 45 r"""The reference original model for evaluation.""" 46 47 self.dd_eval_task_ids: list[int] = dd_eval_task_ids 48 r"""The task IDs to evaluate the DD on. """ 49 self.ag_eval_task_ids: list[int] = ag_eval_task_ids 50 r"""The task IDs to evaluate the AG on. 
""" 51 52 def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int: 53 r"""Get the test task ID from the dataloader index. 54 55 **Args:** 56 - **dataloader_idx** (`int`): the dataloader index. 57 58 **Returns:** 59 - **test_task_id** (`int`): the test task ID. 60 """ 61 dataset_test = self.trainer.datamodule.dataset_test 62 test_task_id = list(dataset_test.keys())[dataloader_idx] 63 return test_task_id 64 65 def test_step( 66 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 67 ) -> dict[str, Tensor]: 68 r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`. 69 70 **Args:** 71 - **batch** (`Any`): a batch of test data. 72 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 73 74 **Returns:** 75 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 
76 """ 77 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 78 79 x, y = batch 80 81 # get the aggregated backbone output (instead of logits) 82 agg_out_main = self.main_model.aggregated_backbone_output(x) 83 agg_out_ref = self.refretrain_model.aggregated_backbone_output(x) 84 85 outputs = { 86 "agg_out_main": agg_out_main, 87 "agg_out_ref": agg_out_ref, 88 } 89 90 if self.reforiginal_model is not None and test_task_id in self.ag_eval_task_ids: 91 logits_main = self.main_model.forward( 92 x, stage="test", task_id=test_task_id 93 )[ 94 0 95 ] # use the corresponding head to test (instead of the current task `self.task_id`) 96 acc_main = (logits_main.argmax(dim=1) == y).float().mean() 97 98 logits_full = self.reforiginal_model.forward( 99 x, stage="test", task_id=test_task_id 100 )[ 101 0 102 ] # use the corresponding head to test (instead of the current task `self.task_id`) 103 acc_full = (logits_full.argmax(dim=1) == y).float().mean() 104 105 # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 106 outputs["acc_gain"] = acc_main - acc_full 107 108 return outputs 109 110 111# print("Unlearning JS divergence results:", eval_module.results) 112 113 114# def compute_unlearning_metrics(output_dir: str) -> None: 115# r"""Compute the unlearning metrics for the continual unlearning experiment and save the results to the `results/` in the output directory. 116 117# **Args:** 118# - **output_dir** (`str`): the output directory path of the continual unlearning experiment. This directory must contain a `unlearning_ref` directory, which contains the output directory of the unlearning reference experiment. 
119# """ 120 121# # initialize unlearning test metrics for unlearned tasks 122 123# for unlearned_task_id in unlearned_task_ids: 124# # test on the unlearned task 125 126# test_dataloader = datamodule.test_dataloader()[ 127# f"{unlearned_task_id}" 128# ] # get the test data 129 130# # set the model to evaluation mode 131# model.to("cpu") 132# model.eval() 133# model_unlearning_test_reference.eval() 134 135# for batch in test_dataloader: 136# # unlearning test step 137# x, _ = batch 138# batch_size = len(batch) 139 140# with torch.no_grad(): 141 142# # get the aggregated backbone output (instead of logits) 143# aggregated_backbone_output = model.aggregated_backbone_output(x) 144# aggregated_backbone_output_unlearning_test_reference = ( 145# model_unlearning_test_reference.aggregated_backbone_output(x) 146# ) 147 148# # calculate the Jensen-Shannon divergence as distribution distance 149# js = js_div( 150# aggregated_backbone_output, 151# aggregated_backbone_output_unlearning_test_reference, 152# ) 153 154# print("js", js) 155# print("js", js) 156# print("js", js) 157# print("js", js) 158# print("js", js) 159# print("js", js)
pylogger =
<Logger clarena.utils.eval (WARNING)>
class
CULEvaluation(lightning.pytorch.core.module.LightningModule):
17class CULEvaluation(LightningModule): 18 r"""Full evaluation module for continual unlearning.""" 19 20 def __init__( 21 self, 22 main_model: CULAlgorithm, 23 refretrain_model: CLAlgorithm, 24 reforiginal_model: CLAlgorithm | None, 25 dd_eval_task_ids: list[int], 26 ag_eval_task_ids: list[int], 27 ): 28 r""" 29 **Args:** 30 - **main_model** (`CULAlgorithm`): the main model to evaluate. 31 - **refretrain_model** (`CLAlgorithm`): the reference retrain model to evaluate against. 32 - **reforiginal_model** (`CLAlgorithm` | `None`): the reference original model that has been trained on all tasks. If `None`, AG-related evaluation is skipped. 33 - **dd_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the DD on. 34 - **ag_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the accuracy gain on. 35 """ 36 super().__init__() 37 38 self.criterion = nn.CrossEntropyLoss() 39 r"""The loss function bewteen the output logits and the target labels. Default is cross-entropy loss.""" 40 41 self.main_model = main_model 42 r"""The main model for evaluation.""" 43 self.refretrain_model = refretrain_model 44 r"""The reference retrain model for evaluation.""" 45 self.reforiginal_model = reforiginal_model 46 r"""The reference original model for evaluation.""" 47 48 self.dd_eval_task_ids: list[int] = dd_eval_task_ids 49 r"""The task IDs to evaluate the DD on. """ 50 self.ag_eval_task_ids: list[int] = ag_eval_task_ids 51 r"""The task IDs to evaluate the AG on. """ 52 53 def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int: 54 r"""Get the test task ID from the dataloader index. 55 56 **Args:** 57 - **dataloader_idx** (`int`): the dataloader index. 58 59 **Returns:** 60 - **test_task_id** (`int`): the test task ID. 
61 """ 62 dataset_test = self.trainer.datamodule.dataset_test 63 test_task_id = list(dataset_test.keys())[dataloader_idx] 64 return test_task_id 65 66 def test_step( 67 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 68 ) -> dict[str, Tensor]: 69 r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`. 70 71 **Args:** 72 - **batch** (`Any`): a batch of test data. 73 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 74 75 **Returns:** 76 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 77 """ 78 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 79 80 x, y = batch 81 82 # get the aggregated backbone output (instead of logits) 83 agg_out_main = self.main_model.aggregated_backbone_output(x) 84 agg_out_ref = self.refretrain_model.aggregated_backbone_output(x) 85 86 outputs = { 87 "agg_out_main": agg_out_main, 88 "agg_out_ref": agg_out_ref, 89 } 90 91 if self.reforiginal_model is not None and test_task_id in self.ag_eval_task_ids: 92 logits_main = self.main_model.forward( 93 x, stage="test", task_id=test_task_id 94 )[ 95 0 96 ] # use the corresponding head to test (instead of the current task `self.task_id`) 97 acc_main = (logits_main.argmax(dim=1) == y).float().mean() 98 99 logits_full = self.reforiginal_model.forward( 100 x, stage="test", task_id=test_task_id 101 )[ 102 0 103 ] # use the corresponding head to test (instead of the current task `self.task_id`) 104 acc_full = (logits_full.argmax(dim=1) == y).float().mean() 105 106 # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 107 outputs["acc_gain"] = acc_main - acc_full 108 109 return outputs
Full evaluation module for continual unlearning.
CULEvaluation( main_model: clarena.cul_algorithms.CULAlgorithm, refretrain_model: clarena.cl_algorithms.CLAlgorithm, reforiginal_model: clarena.cl_algorithms.CLAlgorithm | None, dd_eval_task_ids: list[int], ag_eval_task_ids: list[int])
20 def __init__( 21 self, 22 main_model: CULAlgorithm, 23 refretrain_model: CLAlgorithm, 24 reforiginal_model: CLAlgorithm | None, 25 dd_eval_task_ids: list[int], 26 ag_eval_task_ids: list[int], 27 ): 28 r""" 29 **Args:** 30 - **main_model** (`CULAlgorithm`): the main model to evaluate. 31 - **refretrain_model** (`CLAlgorithm`): the reference retrain model to evaluate against. 32 - **reforiginal_model** (`CLAlgorithm` | `None`): the reference original model that has been trained on all tasks. If `None`, AG-related evaluation is skipped. 33 - **dd_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the DD on. 34 - **ag_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the accuracy gain on. 35 """ 36 super().__init__() 37 38 self.criterion = nn.CrossEntropyLoss() 39 r"""The loss function bewteen the output logits and the target labels. Default is cross-entropy loss.""" 40 41 self.main_model = main_model 42 r"""The main model for evaluation.""" 43 self.refretrain_model = refretrain_model 44 r"""The reference retrain model for evaluation.""" 45 self.reforiginal_model = reforiginal_model 46 r"""The reference original model for evaluation.""" 47 48 self.dd_eval_task_ids: list[int] = dd_eval_task_ids 49 r"""The task IDs to evaluate the DD on. """ 50 self.ag_eval_task_ids: list[int] = ag_eval_task_ids 51 r"""The task IDs to evaluate the AG on. """
Args:
- main_model (
CULAlgorithm): the main model to evaluate. - refretrain_model (
CLAlgorithm): the reference retrain model to evaluate against. - reforiginal_model (
CLAlgorithm | None): the reference original model that has been trained on all tasks. If None, AG-related evaluation is skipped. - dd_eval_task_ids (
list[int]): the list of task IDs to evaluate the DD on. - ag_eval_task_ids (
list[int]): the list of task IDs to evaluate the accuracy gain on.
criterion
The loss function between the output logits and the target labels. Default is cross-entropy loss.
def
get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
53 def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int: 54 r"""Get the test task ID from the dataloader index. 55 56 **Args:** 57 - **dataloader_idx** (`int`): the dataloader index. 58 59 **Returns:** 60 - **test_task_id** (`int`): the test task ID. 61 """ 62 dataset_test = self.trainer.datamodule.dataset_test 63 test_task_id = list(dataset_test.keys())[dataloader_idx] 64 return test_task_id
Get the test task ID from the dataloader index.
Args:
- dataloader_idx (
int): the dataloader index.
Returns:
- test_task_id (
int): the test task ID.
def
test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
66 def test_step( 67 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 68 ) -> dict[str, Tensor]: 69 r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`. 70 71 **Args:** 72 - **batch** (`Any`): a batch of test data. 73 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 74 75 **Returns:** 76 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 77 """ 78 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 79 80 x, y = batch 81 82 # get the aggregated backbone output (instead of logits) 83 agg_out_main = self.main_model.aggregated_backbone_output(x) 84 agg_out_ref = self.refretrain_model.aggregated_backbone_output(x) 85 86 outputs = { 87 "agg_out_main": agg_out_main, 88 "agg_out_ref": agg_out_ref, 89 } 90 91 if self.reforiginal_model is not None and test_task_id in self.ag_eval_task_ids: 92 logits_main = self.main_model.forward( 93 x, stage="test", task_id=test_task_id 94 )[ 95 0 96 ] # use the corresponding head to test (instead of the current task `self.task_id`) 97 acc_main = (logits_main.argmax(dim=1) == y).float().mean() 98 99 logits_full = self.reforiginal_model.forward( 100 x, stage="test", task_id=test_task_id 101 )[ 102 0 103 ] # use the corresponding head to test (instead of the current task `self.task_id`) 104 acc_full = (logits_full.argmax(dim=1) == y).float().mean() 105 106 # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 107 outputs["acc_gain"] = acc_main - acc_full 108 109 return outputs
Test step for current task self.task_id, which tests for all seen tasks indexed by dataloader_idx.
Args:
- batch (
Any): a batch of test data. - dataloader_idx (
int): the task ID of seen tasks to be tested. A default value of 0 is given; otherwise the LightningModule will raise a RuntimeError.
Returns:
- outputs (
dict[str, Tensor]): a dictionary contains loss and other metrics from this test step. Keys (str) are the metrics names, and values (Tensor) are the metrics.