clarena.utils.eval

The submodule in utils for evaluation utilities.

  1r"""The submodule in `utils` for evaluation utilities."""
  2
  3import logging
  4
  5from lightning import LightningModule
  6from torch import Tensor, nn
  7from torch.utils.data import DataLoader
  8
  9from clarena.cl_algorithms import CLAlgorithm
 10from clarena.cul_algorithms import CULAlgorithm
 11
 12# always get logger for built-in logging in each module
 13pylogger = logging.getLogger(__name__)
 14
 15
 16class CULEvaluation(LightningModule):
 17    r"""Full evaluation module for continual unlearning."""
 18
 19    def __init__(
 20        self,
 21        main_model: CULAlgorithm,
 22        refretrain_model: CLAlgorithm,
 23        reforiginal_model: CLAlgorithm | None,
 24        dd_eval_task_ids: list[int],
 25        ag_eval_task_ids: list[int],
 26    ):
 27        r"""
 28        **Args:**
 29        - **main_model** (`CULAlgorithm`): the main model to evaluate.
 30        - **refretrain_model** (`CLAlgorithm`): the reference retrain model to evaluate against.
 31        - **reforiginal_model** (`CLAlgorithm` | `None`): the reference original model that has been trained on all tasks. If `None`, AG-related evaluation is skipped.
 32        - **dd_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the DD on.
 33        - **ag_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the accuracy gain on.
 34        """
 35        super().__init__()
 36
 37        self.criterion = nn.CrossEntropyLoss()
 38        r"""The loss function bewteen the output logits and the target labels. Default is cross-entropy loss."""
 39
 40        self.main_model = main_model
 41        r"""The main model for evaluation."""
 42        self.refretrain_model = refretrain_model
 43        r"""The reference retrain model for evaluation."""
 44        self.reforiginal_model = reforiginal_model
 45        r"""The reference original model for evaluation."""
 46
 47        self.dd_eval_task_ids: list[int] = dd_eval_task_ids
 48        r"""The task IDs to evaluate the DD on. """
 49        self.ag_eval_task_ids: list[int] = ag_eval_task_ids
 50        r"""The task IDs to evaluate the AG on. """
 51
 52    def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
 53        r"""Get the test task ID from the dataloader index.
 54
 55        **Args:**
 56        - **dataloader_idx** (`int`): the dataloader index.
 57
 58        **Returns:**
 59        - **test_task_id** (`int`): the test task ID.
 60        """
 61        dataset_test = self.trainer.datamodule.dataset_test
 62        test_task_id = list(dataset_test.keys())[dataloader_idx]
 63        return test_task_id
 64
 65    def test_step(
 66        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
 67    ) -> dict[str, Tensor]:
 68        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
 69
 70        **Args:**
 71        - **batch** (`Any`): a batch of test data.
 72        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
 73
 74        **Returns:**
 75        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
 76        """
 77        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
 78
 79        x, y = batch
 80
 81        # get the aggregated backbone output (instead of logits)
 82        agg_out_main = self.main_model.aggregated_backbone_output(x)
 83        agg_out_ref = self.refretrain_model.aggregated_backbone_output(x)
 84
 85        outputs = {
 86            "agg_out_main": agg_out_main,
 87            "agg_out_ref": agg_out_ref,
 88        }
 89
 90        if self.reforiginal_model is not None and test_task_id in self.ag_eval_task_ids:
 91            logits_main = self.main_model.forward(
 92                x, stage="test", task_id=test_task_id
 93            )[
 94                0
 95            ]  # use the corresponding head to test (instead of the current task `self.task_id`)
 96            acc_main = (logits_main.argmax(dim=1) == y).float().mean()
 97
 98            logits_full = self.reforiginal_model.forward(
 99                x, stage="test", task_id=test_task_id
100            )[
101                0
102            ]  # use the corresponding head to test (instead of the current task `self.task_id`)
103            acc_full = (logits_full.argmax(dim=1) == y).float().mean()
104
105            # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
106            outputs["acc_gain"] = acc_main - acc_full
107
108        return outputs
109
110
111# print("Unlearning JS divergence results:", eval_module.results)
112
113
114# def compute_unlearning_metrics(output_dir: str) -> None:
115#     r"""Compute the unlearning metrics for the continual unlearning experiment and save the results to the `results/` in the output directory.
116
117#     **Args:**
118#     - **output_dir** (`str`): the output directory path of the continual unlearning experiment. This directory must contain a `unlearning_ref` directory, which contains the output directory of the unlearning reference experiment.
119#     """
120
121#     # initialize unlearning test metrics for unlearned tasks
122
123#     for unlearned_task_id in unlearned_task_ids:
124#         # test on the unlearned task
125
126#         test_dataloader = datamodule.test_dataloader()[
127#             f"{unlearned_task_id}"
128#         ]  # get the test data
129
130#         # set the model to evaluation mode
131#         model.to("cpu")
132#         model.eval()
133#         model_unlearning_test_reference.eval()
134
135#         for batch in test_dataloader:
136#             # unlearning test step
137#             x, _ = batch
138#             batch_size = len(batch)
139
140#             with torch.no_grad():
141
142#                 # get the aggregated backbone output (instead of logits)
143#                 aggregated_backbone_output = model.aggregated_backbone_output(x)
144#                 aggregated_backbone_output_unlearning_test_reference = (
145#                     model_unlearning_test_reference.aggregated_backbone_output(x)
146#                 )
147
148#                 # calculate the Jensen-Shannon divergence as distribution distance
149#                 js = js_div(
150#                     aggregated_backbone_output,
151#                     aggregated_backbone_output_unlearning_test_reference,
152#                 )
153
154#             print("js", js)
155#             print("js", js)
156#             print("js", js)
157#             print("js", js)
158#             print("js", js)
159#             print("js", js)
pylogger = <Logger clarena.utils.eval (WARNING)>
class CULEvaluation(lightning.pytorch.core.module.LightningModule):
 17class CULEvaluation(LightningModule):
 18    r"""Full evaluation module for continual unlearning."""
 19
 20    def __init__(
 21        self,
 22        main_model: CULAlgorithm,
 23        refretrain_model: CLAlgorithm,
 24        reforiginal_model: CLAlgorithm | None,
 25        dd_eval_task_ids: list[int],
 26        ag_eval_task_ids: list[int],
 27    ):
 28        r"""
 29        **Args:**
 30        - **main_model** (`CULAlgorithm`): the main model to evaluate.
 31        - **refretrain_model** (`CLAlgorithm`): the reference retrain model to evaluate against.
 32        - **reforiginal_model** (`CLAlgorithm` | `None`): the reference original model that has been trained on all tasks. If `None`, AG-related evaluation is skipped.
 33        - **dd_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the DD on.
 34        - **ag_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the accuracy gain on.
 35        """
 36        super().__init__()
 37
 38        self.criterion = nn.CrossEntropyLoss()
 39        r"""The loss function bewteen the output logits and the target labels. Default is cross-entropy loss."""
 40
 41        self.main_model = main_model
 42        r"""The main model for evaluation."""
 43        self.refretrain_model = refretrain_model
 44        r"""The reference retrain model for evaluation."""
 45        self.reforiginal_model = reforiginal_model
 46        r"""The reference original model for evaluation."""
 47
 48        self.dd_eval_task_ids: list[int] = dd_eval_task_ids
 49        r"""The task IDs to evaluate the DD on. """
 50        self.ag_eval_task_ids: list[int] = ag_eval_task_ids
 51        r"""The task IDs to evaluate the AG on. """
 52
 53    def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
 54        r"""Get the test task ID from the dataloader index.
 55
 56        **Args:**
 57        - **dataloader_idx** (`int`): the dataloader index.
 58
 59        **Returns:**
 60        - **test_task_id** (`int`): the test task ID.
 61        """
 62        dataset_test = self.trainer.datamodule.dataset_test
 63        test_task_id = list(dataset_test.keys())[dataloader_idx]
 64        return test_task_id
 65
 66    def test_step(
 67        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
 68    ) -> dict[str, Tensor]:
 69        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
 70
 71        **Args:**
 72        - **batch** (`Any`): a batch of test data.
 73        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
 74
 75        **Returns:**
 76        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
 77        """
 78        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
 79
 80        x, y = batch
 81
 82        # get the aggregated backbone output (instead of logits)
 83        agg_out_main = self.main_model.aggregated_backbone_output(x)
 84        agg_out_ref = self.refretrain_model.aggregated_backbone_output(x)
 85
 86        outputs = {
 87            "agg_out_main": agg_out_main,
 88            "agg_out_ref": agg_out_ref,
 89        }
 90
 91        if self.reforiginal_model is not None and test_task_id in self.ag_eval_task_ids:
 92            logits_main = self.main_model.forward(
 93                x, stage="test", task_id=test_task_id
 94            )[
 95                0
 96            ]  # use the corresponding head to test (instead of the current task `self.task_id`)
 97            acc_main = (logits_main.argmax(dim=1) == y).float().mean()
 98
 99            logits_full = self.reforiginal_model.forward(
100                x, stage="test", task_id=test_task_id
101            )[
102                0
103            ]  # use the corresponding head to test (instead of the current task `self.task_id`)
104            acc_full = (logits_full.argmax(dim=1) == y).float().mean()
105
106            # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
107            outputs["acc_gain"] = acc_main - acc_full
108
109        return outputs

Full evaluation module for continual unlearning.

CULEvaluation( main_model: clarena.cul_algorithms.CULAlgorithm, refretrain_model: clarena.cl_algorithms.CLAlgorithm, reforiginal_model: clarena.cl_algorithms.CLAlgorithm | None, dd_eval_task_ids: list[int], ag_eval_task_ids: list[int])
20    def __init__(
21        self,
22        main_model: CULAlgorithm,
23        refretrain_model: CLAlgorithm,
24        reforiginal_model: CLAlgorithm | None,
25        dd_eval_task_ids: list[int],
26        ag_eval_task_ids: list[int],
27    ):
28        r"""
29        **Args:**
30        - **main_model** (`CULAlgorithm`): the main model to evaluate.
31        - **refretrain_model** (`CLAlgorithm`): the reference retrain model to evaluate against.
32        - **reforiginal_model** (`CLAlgorithm` | `None`): the reference original model that has been trained on all tasks. If `None`, AG-related evaluation is skipped.
33        - **dd_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the DD on.
34        - **ag_eval_task_ids** (`list[int]`): the list of task IDs to evaluate the accuracy gain on.
35        """
36        super().__init__()
37
38        self.criterion = nn.CrossEntropyLoss()
39        r"""The loss function bewteen the output logits and the target labels. Default is cross-entropy loss."""
40
41        self.main_model = main_model
42        r"""The main model for evaluation."""
43        self.refretrain_model = refretrain_model
44        r"""The reference retrain model for evaluation."""
45        self.reforiginal_model = reforiginal_model
46        r"""The reference original model for evaluation."""
47
48        self.dd_eval_task_ids: list[int] = dd_eval_task_ids
49        r"""The task IDs to evaluate the DD on. """
50        self.ag_eval_task_ids: list[int] = ag_eval_task_ids
51        r"""The task IDs to evaluate the AG on. """

Args:

  • main_model (CULAlgorithm): the main model to evaluate.
  • refretrain_model (CLAlgorithm): the reference retrain model to evaluate against.
  • reforiginal_model (CLAlgorithm | None): the reference original model that has been trained on all tasks. If None, AG-related evaluation is skipped.
  • dd_eval_task_ids (list[int]): the list of task IDs to evaluate the DD on.
  • ag_eval_task_ids (list[int]): the list of task IDs to evaluate the accuracy gain on.
criterion

The loss function between the output logits and the target labels. Default is cross-entropy loss.

main_model

The main model for evaluation.

refretrain_model

The reference retrain model for evaluation.

reforiginal_model

The reference original model for evaluation.

dd_eval_task_ids: list[int]

The task IDs to evaluate the DD on.

ag_eval_task_ids: list[int]

The task IDs to evaluate the AG on.

def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
53    def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
54        r"""Get the test task ID from the dataloader index.
55
56        **Args:**
57        - **dataloader_idx** (`int`): the dataloader index.
58
59        **Returns:**
60        - **test_task_id** (`int`): the test task ID.
61        """
62        dataset_test = self.trainer.datamodule.dataset_test
63        test_task_id = list(dataset_test.keys())[dataloader_idx]
64        return test_task_id

Get the test task ID from the dataloader index.

Args:

  • dataloader_idx (int): the dataloader index.

Returns:

  • test_task_id (int): the test task ID.
def test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:
 66    def test_step(
 67        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
 68    ) -> dict[str, Tensor]:
 69        r"""Test step for current task `self.task_id`, which tests for all seen tasks indexed by `dataloader_idx`.
 70
 71        **Args:**
 72        - **batch** (`Any`): a batch of test data.
 73        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
 74
 75        **Returns:**
 76        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and other metrics from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
 77        """
 78        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)
 79
 80        x, y = batch
 81
 82        # get the aggregated backbone output (instead of logits)
 83        agg_out_main = self.main_model.aggregated_backbone_output(x)
 84        agg_out_ref = self.refretrain_model.aggregated_backbone_output(x)
 85
 86        outputs = {
 87            "agg_out_main": agg_out_main,
 88            "agg_out_ref": agg_out_ref,
 89        }
 90
 91        if self.reforiginal_model is not None and test_task_id in self.ag_eval_task_ids:
 92            logits_main = self.main_model.forward(
 93                x, stage="test", task_id=test_task_id
 94            )[
 95                0
 96            ]  # use the corresponding head to test (instead of the current task `self.task_id`)
 97            acc_main = (logits_main.argmax(dim=1) == y).float().mean()
 98
 99            logits_full = self.reforiginal_model.forward(
100                x, stage="test", task_id=test_task_id
101            )[
102                0
103            ]  # use the corresponding head to test (instead of the current task `self.task_id`)
104            acc_full = (logits_full.argmax(dim=1) == y).float().mean()
105
106            # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
107            outputs["acc_gain"] = acc_main - acc_full
108
109        return outputs

Test step for current task self.task_id, which tests for all seen tasks indexed by dataloader_idx.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary contains loss and other metrics from this test step. Keys (str) are the metrics names, and values (Tensor) are the metrics.