clarena.metrics.cul_dd

The submodule in metrics for CULDistributionDistance.

  1r"""
  2The submodule in `metrics` for `CULDistributionDistance`.
  3"""
  4
  5__all__ = ["CULDistributionDistance"]
  6
  7import csv
  8import logging
  9import os
 10from typing import Any
 11
 12import pandas as pd
 13import torch
 14from lightning import Trainer
 15from lightning.pytorch.utilities import rank_zero_only
 16from matplotlib import pyplot as plt
 17from torchmetrics import MeanMetric
 18
 19from clarena.metrics import MetricCallback
 20from clarena.utils.eval import CULEvaluation
 21from clarena.utils.metrics import MeanMetricBatch, linear_cka
 22
 23# always get logger for built-in logging in each module
 24pylogger = logging.getLogger(__name__)
 25
 26
 27class CULDistributionDistance(MetricCallback):
 28    r"""Provides all actions that are related to CUL distribution distance (DD) metric, which include:
 29
 30    - Defining, initializing and recording DD metric.
 31    - Saving DD metric to files.
 32    - Visualizing DD metric as plots.
 33
 34    The callback is able to produce the following outputs:
 35
 36    - CSV files for DD in each task.
 37    - Coloured plot for DD in each task.
 38
 39    Note that this callback is designed to be used with the `CULEvaluation` module, which is a special evaluation module for continual unlearning. It is not a typical test step in the algorithm, but rather a test protocol that evaluates the performance of the model on unlearned tasks.
 40    """
 41
 42    def __init__(
 43        self,
 44        save_dir: str,
 45        distribution_distance_type: str,
 46        distribution_distance_csv_name: str = "dd.csv",
 47        distribution_distance_plot_name: str | None = None,
 48        average_scope: str = "unlearned",
 49    ) -> None:
 50        r"""
 51        **Args:**
 52        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 53        - **distribution_distance_type** (`str`): the type of distribution distance to use; one of:
 54            - 'euclidean': Euclidean distance.
 55            - 'cosine': Cosine distance. This is calculated as 1 - cosine similarity, where the cosine similarity is the mean of the cosine similarities between the main and reference model outputs across the batch.
 56            - 'linear_cka': Linear CKA distance. This is calculated as 1 - linear CKA similarity. Please note batch size should be larger than 1 when using CKA distance, otherwise the distance will be 0.0.
 57            - 'manhattan': Manhattan distance.
 58        - **distribution_distance_csv_name** (`str`): file name to save test distribution distance metrics as CSV file.
 59        - **distribution_distance_plot_name** (`str` | `None`): file name to save test distribution distance metrics as plot. If `None`, no plot will be saved.
 60        - **average_scope** (`str`): scope to compute average DD over tasks, must be one of:
 61            - 'all': compute average DD over all eval tasks.
 62            - 'remaining': compute average DD over remaining tasks (exclude unlearned tasks).
 63            - 'unlearned': compute average DD over unlearned tasks.
 64
 65        """
 66        super().__init__(save_dir=save_dir)
 67
 68        self.distribution_distance_type: str = distribution_distance_type
 69        r"""The type of distribution distance to use."""
 70
 71        # paths
 72        self.distribution_distance_csv_path: str = os.path.join(
 73            save_dir, distribution_distance_csv_name
 74        )
 75        r"""The path to save the test distribution distance metrics CSV file."""
 76        if distribution_distance_plot_name:
 77            self.distribution_distance_plot_path: str = os.path.join(
 78                save_dir, distribution_distance_plot_name
 79            )
 80            r"""The path to save the test distribution distance metrics plot file."""
 81
 82        # average scope control
 83        self.average_scope: str = average_scope
 84        r"""The scope to compute average DD over tasks."""
 85
 86        self.average_task_ids: list[int] = []
 87        r"""Task IDs used to compute the average DD. Defaults to unlearned tasks if available; otherwise all eval tasks."""
 88
 89        # test accumulated metrics
 90        self.distribution_distance: dict[int, MeanMetricBatch]
 91        r"""Distribution distance unlearning metrics for each seen task. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics."""
 92
 93        self.sanity_check()
 94
 95    def sanity_check(self) -> None:
 96        r"""Sanity check."""
 97
 98        valid_distribution_distance_types = {
 99            "euclidean",
100            "cosine",
101            "linear_cka",
102            "manhattan",
103        }
104        if self.distribution_distance_type not in valid_distribution_distance_types:
105            raise ValueError(
106                f"Invalid distribution_distance_type: {self.distribution_distance_type} in `CULDistributionDistance`. Must be one of {sorted(valid_distribution_distance_types)}."
107            )
108
109        if self.average_scope not in ["all", "remaining", "unlearned"]:
110            raise ValueError(
111                f"Invalid average_scope: {self.average_scope} in `CULDistributionDistance`. Must be one of 'all', 'remaining', or 'unlearned'."
112            )
113
114    @rank_zero_only
115    def on_test_start(
116        self,
117        trainer: Trainer,
118        pl_module: CULEvaluation,
119    ) -> None:
120        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
121
122        # get the device to put the metrics on the same device
123        device = pl_module.device
124
125        # initialize test metrics for evaluation tasks
126        self.distribution_distance = {
127            task_id: MeanMetricBatch().to(device)
128            for task_id in pl_module.dd_eval_task_ids
129        }
130
131        eval_task_ids = pl_module.dd_eval_task_ids
132        unlearned_task_ids = set(
133            getattr(pl_module.main_model, "unlearned_task_ids", [])
134        )
135        if self.average_scope == "all":
136            self.average_task_ids = eval_task_ids
137        elif self.average_scope == "unlearned":
138            self.average_task_ids = [
139                task_id for task_id in eval_task_ids if task_id in unlearned_task_ids
140            ]
141        else:
142            self.average_task_ids = [
143                task_id
144                for task_id in eval_task_ids
145                if task_id not in unlearned_task_ids
146            ]
147        if not self.average_task_ids:
148            self.average_task_ids = eval_task_ids
149
150    @rank_zero_only
151    def on_test_batch_end(
152        self,
153        trainer: Trainer,
154        pl_module: CULEvaluation,
155        outputs: dict[str, Any],
156        batch: Any,
157        batch_idx: int,
158        dataloader_idx: int = 0,
159    ) -> None:
160        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
161
162        **Args:**
163        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CULEvaluation`.
164        - **batch** (`Any`): the test data batch.
165        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
166        """
167
168        # get the batch size
169        batch_size = len(batch)
170
171        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
172        if test_task_id not in self.distribution_distance:
173            return
174
175        # get the raw outputs from the outputs dictionary
176        agg_out_main = outputs["agg_out_main"]  # aggregated outputs from the main model
177        agg_out_ref = outputs[
178            "agg_out_ref"
179        ]  # aggregated outputs from the reference model
180
181        if agg_out_main.dim() != 2:
182            raise ValueError(
183                f"Expected aggregated outputs to be (batch_size, flattened feature), i.e. 2 dimension, but got {agg_out_main.dim()}."
184            )
185
186        if agg_out_ref.dim() != 2:
187            raise ValueError(
188                f"Expected aggregated outputs to be (batch_size, flattened feature), i.e. 2 dimension, but got {agg_out_ref.dim()}."
189            )
190
191        # calculate the distribution distance between the main and reference model outputs
192        if self.distribution_distance_type == "euclidean":
193            distance_batch = torch.norm(
194                agg_out_main - agg_out_ref, p=2, dim=-1
195            ).mean()  # Euclidean distance
196
197        elif self.distribution_distance_type == "cosine":
198            distance_batch = (
199                1
200                - (
201                    torch.nn.functional.cosine_similarity(
202                        agg_out_main, agg_out_ref, dim=-1
203                    )
204                ).mean()
205            )  # cosine distance
206        elif self.distribution_distance_type == "manhattan":
207            distance_batch = torch.norm(
208                agg_out_main - agg_out_ref, p=1, dim=-1
209            ).mean()  # Manhattan distance
210        elif self.distribution_distance_type == "linear_cka":
211            distance_batch = 1 - linear_cka(agg_out_main, agg_out_ref)  # CKA distance
212        else:
213            raise ValueError(
214                f"Invalid distribution_distance_type: {self.distribution_distance_type}"
215            )
216
217        # update the accumulated metrics in order to calculate the metrics of the epoch
218        self.distribution_distance[test_task_id].update(distance_batch, batch_size)
219
220    @rank_zero_only
221    def on_test_epoch_end(
222        self,
223        trainer: Trainer,
224        pl_module: CULEvaluation,
225    ) -> None:
226        r"""Save and plot test metrics at the end of test."""
227
228        self.update_distribution_distance_to_csv(
229            csv_path=self.distribution_distance_csv_path,
230        )
231
232        if hasattr(self, "distribution_distance_plot_path"):
233            self.plot_distribution_distance_from_csv(
234                csv_path=self.distribution_distance_csv_path,
235                plot_path=self.distribution_distance_plot_path,
236            )
237
238    def update_distribution_distance_to_csv(
239        self,
240        csv_path: str,
241    ) -> None:
242        r"""Update the unlearning distribution distance metrics of unlearning tasks to CSV file.
243
244        **Args:**
245        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/unlearning_test_after_task_X/distance.csv'.
246
247        """
248
249        eval_task_ids = list(self.distribution_distance.keys())
250        average_col_name = f"average_distribution_distance_on_{self.average_scope}"
251        fieldnames = [average_col_name] + [
252            f"unlearning_test_on_task_{task_id}" for task_id in eval_task_ids
253        ]
254
255        new_line = {}
256
257        # write to the columns and calculate the average distribution distance over selected tasks
258        average_distribution_distance_over_tasks = MeanMetric().to(
259            device=next(iter(self.distribution_distance.values())).device
260        )
261        for task_id in eval_task_ids:
262            distance = self.distribution_distance[task_id].compute().item()
263            new_line[f"unlearning_test_on_task_{task_id}"] = distance
264            if task_id in self.average_task_ids:
265                average_distribution_distance_over_tasks(distance)
266        average_distribution_distance = (
267            average_distribution_distance_over_tasks.compute().item()
268        )
269        new_line[average_col_name] = average_distribution_distance
270
271        # write to the csv file
272        is_first = not os.path.exists(csv_path)
273        if not is_first:
274            with open(csv_path, "r", encoding="utf-8") as file:
275                lines = file.readlines()
276                del lines[0]
277        # write header
278        with open(csv_path, "w", encoding="utf-8") as file:
279            writer = csv.DictWriter(file, fieldnames=fieldnames)
280            writer.writeheader()
281        # write metrics
282        with open(csv_path, "a", encoding="utf-8") as file:
283            if not is_first:
284                file.writelines(lines)  # write the previous lines
285            writer = csv.DictWriter(file, fieldnames=fieldnames)
286            writer.writerow(new_line)
287
288    def plot_distribution_distance_from_csv(
289        self, csv_path: str, plot_path: str
290    ) -> None:
291        """Plot the unlearning test distance matrix over different unlearned tasks from saved CSV file and save the plot to the designated directory.
292
293        **Args:**
294        - **csv_path** (`str`): the path to the CSV file where the `update_distribution_distance_to_csv()` saved the unlearning test distance metric.
295        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/unlearning_test_after_task_X/distance.png'.
296        """
297        data = pd.read_csv(csv_path)
298
299        eval_task_ids = [
300            int(col.replace("unlearning_test_on_task_", ""))
301            for col in data.columns
302            if col.startswith("unlearning_test_on_task_")
303        ]
304        test_task_cols = [
305            col for col in data.columns if col.startswith("unlearning_test_on_task_")
306        ]
307        num_tasks = len(eval_task_ids)
308        num_rows = len(data)
309
310        # Build the distribution distance matrix
311        distance_matrix = data[test_task_cols].values
312
313        # plot the distribution distance matrix
314        fig, ax = plt.subplots(
315            figsize=(2 * num_tasks, 2 * num_rows)
316        )  # adaptive figure size
317        vmin = float(distance_matrix.min())
318        vmax = float(distance_matrix.max())
319        cax = ax.imshow(
320            distance_matrix,
321            interpolation="nearest",
322            cmap="Greens",
323            vmin=vmin,
324            vmax=vmax,
325            aspect="auto",
326        )
327
328        colorbar = fig.colorbar(cax)
329        yticks = colorbar.ax.get_yticks()
330        colorbar.ax.set_yticks(yticks)
331        colorbar.ax.set_yticklabels(
332            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
333        )  # adaptive font size
334
335        for r in range(num_rows):
336            for c in range(num_tasks):
337                ax.text(
338                    c,
339                    r,
340                    f"{distance_matrix[r, c]:.3f}",
341                    ha="center",
342                    va="center",
343                    color="black",
344                    fontsize=10 + num_tasks,  # adaptive font size
345                )
346
347        ax.set_xticks(range(num_tasks))
348        ax.set_yticks(range(num_rows))
349        ax.set_xticklabels(eval_task_ids, fontsize=10 + num_tasks)  # adaptive font size
350        ax.set_yticklabels(
351            range(1, num_rows + 1), fontsize=10 + num_rows
352        )  # adaptive font size
353
354        # Labeling the axes
355        ax.set_xlabel(
356            "Testing unlearning on task τ", fontsize=10 + num_tasks
357        )  # adaptive font size
358        ax.set_ylabel(
359            "Unlearning test after training task t", fontsize=10 + num_tasks
360        )  # adaptive font size
361        fig.tight_layout()
362        fig.savefig(plot_path)
363        plt.close(fig)
class CULDistributionDistance(MetricCallback):
    r"""Provides all actions that are related to CUL distribution distance (DD) metric, which include:

    - Defining, initializing and recording DD metric.
    - Saving DD metric to files.
    - Visualizing DD metric as plots.

    The callback is able to produce the following outputs:

    - CSV files for DD in each task.
    - Coloured plot for DD in each task.

    Note that this callback is designed to be used with the `CULEvaluation` module, which is a special evaluation module for continual unlearning. It is not a typical test step in the algorithm, but rather a test protocol that evaluates the performance of the model on unlearned tasks.
    """

    def __init__(
        self,
        save_dir: str,
        distribution_distance_type: str,
        distribution_distance_csv_name: str = "dd.csv",
        distribution_distance_plot_name: str | None = None,
        average_scope: str = "unlearned",
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
        - **distribution_distance_type** (`str`): the type of distribution distance to use; one of:
            - 'euclidean': Euclidean distance.
            - 'cosine': Cosine distance. This is calculated as 1 - cosine similarity, where the cosine similarity is the mean of the cosine similarities between the main and reference model outputs across the batch.
            - 'linear_cka': Linear CKA distance. This is calculated as 1 - linear CKA similarity. Please note batch size should be larger than 1 when using CKA distance, otherwise the distance will be 0.0.
            - 'manhattan': Manhattan distance.
        - **distribution_distance_csv_name** (`str`): file name to save test distribution distance metrics as CSV file.
        - **distribution_distance_plot_name** (`str` | `None`): file name to save test distribution distance metrics as plot. If `None`, no plot will be saved.
        - **average_scope** (`str`): scope to compute average DD over tasks, must be one of:
            - 'all': compute average DD over all eval tasks.
            - 'remaining': compute average DD over remaining tasks (exclude unlearned tasks).
            - 'unlearned': compute average DD over unlearned tasks.
        """
        super().__init__(save_dir=save_dir)

        self.distribution_distance_type: str = distribution_distance_type
        r"""The type of distribution distance to use."""

        # paths
        self.distribution_distance_csv_path: str = os.path.join(
            save_dir, distribution_distance_csv_name
        )
        r"""The path to save the test distribution distance metrics CSV file."""
        if distribution_distance_plot_name:
            # attribute exists only when a plot is requested; checked via hasattr()
            self.distribution_distance_plot_path: str = os.path.join(
                save_dir, distribution_distance_plot_name
            )
            r"""The path to save the test distribution distance metrics plot file."""

        # average scope control
        self.average_scope: str = average_scope
        r"""The scope to compute average DD over tasks."""

        self.average_task_ids: list[int] = []
        r"""Task IDs used to compute the average DD. Defaults to unlearned tasks if available; otherwise all eval tasks."""

        # test accumulated metrics
        self.distribution_distance: dict[int, MeanMetricBatch]
        r"""Distribution distance unlearning metrics for each seen task. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics."""

        self.sanity_check()

    def sanity_check(self) -> None:
        r"""Validate constructor arguments; raise `ValueError` on invalid values."""

        valid_distribution_distance_types = {
            "euclidean",
            "cosine",
            "linear_cka",
            "manhattan",
        }
        if self.distribution_distance_type not in valid_distribution_distance_types:
            raise ValueError(
                f"Invalid distribution_distance_type: {self.distribution_distance_type} in `CULDistributionDistance`. Must be one of {sorted(valid_distribution_distance_types)}."
            )

        if self.average_scope not in ["all", "remaining", "unlearned"]:
            raise ValueError(
                f"Invalid average_scope: {self.average_scope} in `CULDistributionDistance`. Must be one of 'all', 'remaining', or 'unlearned'."
            )

    @rank_zero_only
    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: CULEvaluation,
    ) -> None:
        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize test metrics for evaluation tasks
        self.distribution_distance = {
            task_id: MeanMetricBatch().to(device)
            for task_id in pl_module.dd_eval_task_ids
        }

        # resolve which task IDs contribute to the average, per `average_scope`
        eval_task_ids = pl_module.dd_eval_task_ids
        unlearned_task_ids = set(
            getattr(pl_module.main_model, "unlearned_task_ids", [])
        )
        if self.average_scope == "all":
            self.average_task_ids = eval_task_ids
        elif self.average_scope == "unlearned":
            self.average_task_ids = [
                task_id for task_id in eval_task_ids if task_id in unlearned_task_ids
            ]
        else:  # "remaining": exclude unlearned tasks
            self.average_task_ids = [
                task_id
                for task_id in eval_task_ids
                if task_id not in unlearned_task_ids
            ]
        # fall back to all eval tasks so the average is always well-defined
        if not self.average_task_ids:
            self.average_task_ids = eval_task_ids

    @rank_zero_only
    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: CULEvaluation,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CULEvaluation`.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
        """

        # get the batch size
        # NOTE(review): if `batch` is an (inputs, targets) tuple, `len(batch)` is the
        # tuple length, not the number of samples — confirm the intended weighting.
        batch_size = len(batch)

        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
        if test_task_id not in self.distribution_distance:
            return  # this task is not under DD evaluation

        # get the raw outputs from the outputs dictionary
        agg_out_main = outputs["agg_out_main"]  # aggregated outputs from the main model
        agg_out_ref = outputs[
            "agg_out_ref"
        ]  # aggregated outputs from the reference model

        if agg_out_main.dim() != 2:
            raise ValueError(
                f"Expected aggregated outputs to be (batch_size, flattened feature), i.e. 2 dimension, but got {agg_out_main.dim()}."
            )

        if agg_out_ref.dim() != 2:
            raise ValueError(
                f"Expected aggregated outputs to be (batch_size, flattened feature), i.e. 2 dimension, but got {agg_out_ref.dim()}."
            )

        # calculate the distribution distance between the main and reference model outputs
        if self.distribution_distance_type == "euclidean":
            distance_batch = torch.norm(
                agg_out_main - agg_out_ref, p=2, dim=-1
            ).mean()  # Euclidean distance
        elif self.distribution_distance_type == "cosine":
            distance_batch = (
                1
                - (
                    torch.nn.functional.cosine_similarity(
                        agg_out_main, agg_out_ref, dim=-1
                    )
                ).mean()
            )  # cosine distance
        elif self.distribution_distance_type == "manhattan":
            distance_batch = torch.norm(
                agg_out_main - agg_out_ref, p=1, dim=-1
            ).mean()  # Manhattan distance
        elif self.distribution_distance_type == "linear_cka":
            distance_batch = 1 - linear_cka(agg_out_main, agg_out_ref)  # CKA distance
        else:
            # unreachable after sanity_check(); kept as a defensive guard
            raise ValueError(
                f"Invalid distribution_distance_type: {self.distribution_distance_type}"
            )

        # update the accumulated metrics in order to calculate the metrics of the epoch
        self.distribution_distance[test_task_id].update(distance_batch, batch_size)

    @rank_zero_only
    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CULEvaluation,
    ) -> None:
        r"""Save and plot test metrics at the end of test."""

        self.update_distribution_distance_to_csv(
            csv_path=self.distribution_distance_csv_path,
        )

        # the plot path attribute only exists when a plot name was configured
        if hasattr(self, "distribution_distance_plot_path"):
            self.plot_distribution_distance_from_csv(
                csv_path=self.distribution_distance_csv_path,
                plot_path=self.distribution_distance_plot_path,
            )

    def update_distribution_distance_to_csv(
        self,
        csv_path: str,
    ) -> None:
        r"""Update the unlearning distribution distance metrics of unlearning tasks to CSV file.

        **Args:**
        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/unlearning_test_after_task_X/distance.csv'.
        """

        eval_task_ids = list(self.distribution_distance.keys())
        average_col_name = f"average_distribution_distance_on_{self.average_scope}"
        fieldnames = [average_col_name] + [
            f"unlearning_test_on_task_{task_id}" for task_id in eval_task_ids
        ]

        # compute per-task distances and the average over the selected scope
        new_line = {}
        average_distribution_distance_over_tasks = MeanMetric().to(
            device=next(iter(self.distribution_distance.values())).device
        )
        for task_id in eval_task_ids:
            distance = self.distribution_distance[task_id].compute().item()
            new_line[f"unlearning_test_on_task_{task_id}"] = distance
            if task_id in self.average_task_ids:
                average_distribution_distance_over_tasks(distance)
        new_line[average_col_name] = (
            average_distribution_distance_over_tasks.compute().item()
        )

        # keep previously written data rows (drop the old header line, if any)
        previous_rows: list[str] = []
        if os.path.exists(csv_path):
            with open(csv_path, "r", encoding="utf-8") as file:
                previous_rows = file.readlines()[1:]

        # rewrite the file in one pass: fresh header, previous rows, then the new row.
        # `newline=""` is required by the `csv` module to avoid stray blank lines on Windows.
        # NOTE(review): if the eval task set changes between calls, previous rows no
        # longer align with the new header — confirm the task set is fixed per run.
        with open(csv_path, "w", encoding="utf-8", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
            file.writelines(previous_rows)
            writer.writerow(new_line)

    def plot_distribution_distance_from_csv(
        self, csv_path: str, plot_path: str
    ) -> None:
        r"""Plot the unlearning test distance matrix over different unlearned tasks from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where the `update_distribution_distance_to_csv()` saved the unlearning test distance metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/unlearning_test_after_task_X/distance.png'.
        """
        data = pd.read_csv(csv_path)

        # recover task IDs from the per-task column names
        eval_task_ids = [
            int(col.replace("unlearning_test_on_task_", ""))
            for col in data.columns
            if col.startswith("unlearning_test_on_task_")
        ]
        test_task_cols = [
            col for col in data.columns if col.startswith("unlearning_test_on_task_")
        ]
        num_tasks = len(eval_task_ids)
        num_rows = len(data)

        # Build the distribution distance matrix (rows: successive tests, cols: tasks)
        distance_matrix = data[test_task_cols].values

        # plot the distribution distance matrix
        fig, ax = plt.subplots(
            figsize=(2 * num_tasks, 2 * num_rows)
        )  # adaptive figure size
        vmin = float(distance_matrix.min())
        vmax = float(distance_matrix.max())
        cax = ax.imshow(
            distance_matrix,
            interpolation="nearest",
            cmap="Greens",
            vmin=vmin,
            vmax=vmax,
            aspect="auto",
        )

        colorbar = fig.colorbar(cax)
        yticks = colorbar.ax.get_yticks()
        colorbar.ax.set_yticks(yticks)
        colorbar.ax.set_yticklabels(
            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
        )  # adaptive font size

        # annotate every cell with its value
        for r in range(num_rows):
            for c in range(num_tasks):
                ax.text(
                    c,
                    r,
                    f"{distance_matrix[r, c]:.3f}",
                    ha="center",
                    va="center",
                    color="black",
                    fontsize=10 + num_tasks,  # adaptive font size
                )

        ax.set_xticks(range(num_tasks))
        ax.set_yticks(range(num_rows))
        ax.set_xticklabels(eval_task_ids, fontsize=10 + num_tasks)  # adaptive font size
        ax.set_yticklabels(
            range(1, num_rows + 1), fontsize=10 + num_rows
        )  # adaptive font size

        # Labeling the axes
        ax.set_xlabel(
            "Testing unlearning on task τ", fontsize=10 + num_tasks
        )  # adaptive font size
        ax.set_ylabel(
            "Unlearning test after training task t", fontsize=10 + num_tasks
        )  # adaptive font size
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)

Provides all actions that are related to CUL distribution distance (DD) metric, which include:

  • Defining, initializing and recording DD metric.
  • Saving DD metric to files.
  • Visualizing DD metric as plots.

The callback is able to produce the following outputs:

  • CSV files for DD in each task.
  • Coloured plot for DD in each task.

Note that this callback is designed to be used with the CULEvaluation module, which is a special evaluation module for continual unlearning. It is not a typical test step in the algorithm, but rather a test protocol that evaluates the performance of the model on unlearned tasks.

CULDistributionDistance( save_dir: str, distribution_distance_type: str, distribution_distance_csv_name: str = 'dd.csv', distribution_distance_plot_name: str | None = None, average_scope: str = 'unlearned')
43    def __init__(
44        self,
45        save_dir: str,
46        distribution_distance_type: str,
47        distribution_distance_csv_name: str = "dd.csv",
48        distribution_distance_plot_name: str | None = None,
49        average_scope: str = "unlearned",
50    ) -> None:
51        r"""
52        **Args:**
53        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
54        - **distribution_distance_type** (`str`): the type of distribution distance to use; one of:
55            - 'euclidean': Euclidean distance.
56            - 'cosine': Cosine distance. This is calculated as 1 - cosine similarity, where the cosine similarity is the mean of the cosine similarities between the main and reference model outputs across the batch.
57            - 'linear_cka': Linear CKA distance. This is calculated as 1 - linear CKA similarity. Please note batch size should be larger than 1 when using CKA distance, otherwise the distance will be 0.0.
58            - 'manhattan': Manhattan distance.
59        - **distribution_distance_csv_name** (`str`): file name to save test distribution distance metrics as CSV file.
60        - **distribution_distance_plot_name** (`str` | `None`): file name to save test distribution distance metrics as plot. If `None`, no plot will be saved.
61        - **average_scope** (`str`): scope to compute average DD over tasks, must be one of:
62            - 'all': compute average DD over all eval tasks.
63            - 'remaining': compute average DD over remaining tasks (exclude unlearned tasks).
64            - 'unlearned': compute average DD over unlearned tasks.
65
66        """
67        super().__init__(save_dir=save_dir)
68
69        self.distribution_distance_type: str = distribution_distance_type
70        r"""The type of distribution distance to use."""
71
72        # paths
73        self.distribution_distance_csv_path: str = os.path.join(
74            save_dir, distribution_distance_csv_name
75        )
76        r"""The path to save the test distribution distance metrics CSV file."""
77        if distribution_distance_plot_name:
78            self.distribution_distance_plot_path: str = os.path.join(
79                save_dir, distribution_distance_plot_name
80            )
81            r"""The path to save the test distribution distance metrics plot file."""
82
83        # average scope control
84        self.average_scope: str = average_scope
85        r"""The scope to compute average DD over tasks."""
86
87        self.average_task_ids: list[int] = []
88        r"""Task IDs used to compute the average DD. Defaults to unlearned tasks if available; otherwise all eval tasks."""
89
90        # test accumulated metrics
91        self.distribution_distance: dict[int, MeanMetricBatch]
92        r"""Distribution distance unlearning metrics for each seen task. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics."""
93
94        self.sanity_check()

Args:

  • save_dir (str): The directory where data and figures of metrics will be saved. Better inside the output folder.
  • distribution_distance_type (str): the type of distribution distance to use; one of:
    • 'euclidean': Euclidean distance.
    • 'cosine': Cosine distance. This is calculated as 1 - cosine similarity, where the cosine similarity is the mean of the cosine similarities between the main and reference model outputs across the batch.
    • 'linear_cka': Linear CKA distance. This is calculated as 1 - linear CKA similarity. Please note batch size should be larger than 1 when using CKA distance, otherwise the distance will be 0.0.
    • 'manhattan': Manhattan distance.
  • distribution_distance_csv_name (str): file name to save test distribution distance metrics as CSV file.
  • distribution_distance_plot_name (str | None): file name to save test distribution distance metrics as plot. If None, no plot will be saved.
  • average_scope (str): scope to compute average DD over tasks, must be one of:
    • 'all': compute average DD over all eval tasks.
    • 'remaining': compute average DD over remaining tasks (exclude unlearned tasks).
    • 'unlearned': compute average DD over unlearned tasks.
distribution_distance_type: str

The type of distribution distance to use.

distribution_distance_csv_path: str

The path to save the test distribution distance metrics CSV file.

average_scope: str

The scope to compute average DD over tasks.

average_task_ids: list[int]

Task IDs used to compute the average DD, selected according to average_scope; falls back to all eval tasks when the selection is empty.

distribution_distance: dict[int, clarena.utils.metrics.MeanMetricBatch]

Distribution distance unlearning metrics for each seen task. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics.

def sanity_check(self) -> None:
 96    def sanity_check(self) -> None:
 97        r"""Sanity check."""
 98
 99        valid_distribution_distance_types = {
100            "euclidean",
101            "cosine",
102            "linear_cka",
103            "manhattan",
104        }
105        if self.distribution_distance_type not in valid_distribution_distance_types:
106            raise ValueError(
107                f"Invalid distribution_distance_type: {self.distribution_distance_type} in `CULDistributionDistance`. Must be one of {sorted(valid_distribution_distance_types)}."
108            )
109
110        if self.average_scope not in ["all", "remaining", "unlearned"]:
111            raise ValueError(
112                f"Invalid average_scope: {self.average_scope} in `CULDistributionDistance`. Must be one of 'all', 'remaining', or 'unlearned'."
113            )

Sanity check.

@rank_zero_only
def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.utils.eval.CULEvaluation) -> None:
115    @rank_zero_only
116    def on_test_start(
117        self,
118        trainer: Trainer,
119        pl_module: CULEvaluation,
120    ) -> None:
121        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
122
123        # get the device to put the metrics on the same device
124        device = pl_module.device
125
126        # initialize test metrics for evaluation tasks
127        self.distribution_distance = {
128            task_id: MeanMetricBatch().to(device)
129            for task_id in pl_module.dd_eval_task_ids
130        }
131
132        eval_task_ids = pl_module.dd_eval_task_ids
133        unlearned_task_ids = set(
134            getattr(pl_module.main_model, "unlearned_task_ids", [])
135        )
136        if self.average_scope == "all":
137            self.average_task_ids = eval_task_ids
138        elif self.average_scope == "unlearned":
139            self.average_task_ids = [
140                task_id for task_id in eval_task_ids if task_id in unlearned_task_ids
141            ]
142        else:
143            self.average_task_ids = [
144                task_id
145                for task_id in eval_task_ids
146                if task_id not in unlearned_task_ids
147            ]
148        if not self.average_task_ids:
149            self.average_task_ids = eval_task_ids

Initialize the metrics for testing each seen task in the beginning of a task's testing.

@rank_zero_only
def on_test_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.utils.eval.CULEvaluation, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
151    @rank_zero_only
152    def on_test_batch_end(
153        self,
154        trainer: Trainer,
155        pl_module: CULEvaluation,
156        outputs: dict[str, Any],
157        batch: Any,
158        batch_idx: int,
159        dataloader_idx: int = 0,
160    ) -> None:
161        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
162
163        **Args:**
164        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CULEvaluation`.
165        - **batch** (`Any`): the test data batch.
166        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
167        """
168
169        # get the batch size
170        batch_size = len(batch)
171
172        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
173        if test_task_id not in self.distribution_distance:
174            return
175
176        # get the raw outputs from the outputs dictionary
177        agg_out_main = outputs["agg_out_main"]  # aggregated outputs from the main model
178        agg_out_ref = outputs[
179            "agg_out_ref"
180        ]  # aggregated outputs from the reference model
181
182        if agg_out_main.dim() != 2:
183            raise ValueError(
184                f"Expected aggregated outputs to be (batch_size, flattened feature), i.e. 2 dimension, but got {agg_out_main.dim()}."
185            )
186
187        if agg_out_ref.dim() != 2:
188            raise ValueError(
189                f"Expected aggregated outputs to be (batch_size, flattened feature), i.e. 2 dimension, but got {agg_out_ref.dim()}."
190            )
191
192        # calculate the distribution distance between the main and reference model outputs
193        if self.distribution_distance_type == "euclidean":
194            distance_batch = torch.norm(
195                agg_out_main - agg_out_ref, p=2, dim=-1
196            ).mean()  # Euclidean distance
197
198        elif self.distribution_distance_type == "cosine":
199            distance_batch = (
200                1
201                - (
202                    torch.nn.functional.cosine_similarity(
203                        agg_out_main, agg_out_ref, dim=-1
204                    )
205                ).mean()
206            )  # cosine distance
207        elif self.distribution_distance_type == "manhattan":
208            distance_batch = torch.norm(
209                agg_out_main - agg_out_ref, p=1, dim=-1
210            ).mean()  # Manhattan distance
211        elif self.distribution_distance_type == "linear_cka":
212            distance_batch = 1 - linear_cka(agg_out_main, agg_out_ref)  # CKA distance
213        else:
214            raise ValueError(
215                f"Invalid distribution_distance_type: {self.distribution_distance_type}"
216            )
217
218        # update the accumulated metrics in order to calculate the metrics of the epoch
219        self.distribution_distance[test_task_id].update(distance_batch, batch_size)

Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, which is the returns of the test_step() method in the CULEvaluation.
  • batch (Any): the test data batch.
  • dataloader_idx (int): the task ID of the seen task being tested. A default value of 0 is given; otherwise the LightningModule will raise a RuntimeError.
@rank_zero_only
def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.utils.eval.CULEvaluation) -> None:
221    @rank_zero_only
222    def on_test_epoch_end(
223        self,
224        trainer: Trainer,
225        pl_module: CULEvaluation,
226    ) -> None:
227        r"""Save and plot test metrics at the end of test."""
228
229        self.update_distribution_distance_to_csv(
230            csv_path=self.distribution_distance_csv_path,
231        )
232
233        if hasattr(self, "distribution_distance_plot_path"):
234            self.plot_distribution_distance_from_csv(
235                csv_path=self.distribution_distance_csv_path,
236                plot_path=self.distribution_distance_plot_path,
237            )

Save and plot test metrics at the end of test.

def update_distribution_distance_to_csv(self, csv_path: str) -> None:
239    def update_distribution_distance_to_csv(
240        self,
241        csv_path: str,
242    ) -> None:
243        r"""Update the unlearning distribution distance metrics of unlearning tasks to CSV file.
244
245        **Args:**
246        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/unlearning_test_after_task_X/distance.csv'.
247
248        """
249
250        eval_task_ids = list(self.distribution_distance.keys())
251        average_col_name = f"average_distribution_distance_on_{self.average_scope}"
252        fieldnames = [average_col_name] + [
253            f"unlearning_test_on_task_{task_id}" for task_id in eval_task_ids
254        ]
255
256        new_line = {}
257
258        # write to the columns and calculate the average distribution distance over selected tasks
259        average_distribution_distance_over_tasks = MeanMetric().to(
260            device=next(iter(self.distribution_distance.values())).device
261        )
262        for task_id in eval_task_ids:
263            distance = self.distribution_distance[task_id].compute().item()
264            new_line[f"unlearning_test_on_task_{task_id}"] = distance
265            if task_id in self.average_task_ids:
266                average_distribution_distance_over_tasks(distance)
267        average_distribution_distance = (
268            average_distribution_distance_over_tasks.compute().item()
269        )
270        new_line[average_col_name] = average_distribution_distance
271
272        # write to the csv file
273        is_first = not os.path.exists(csv_path)
274        if not is_first:
275            with open(csv_path, "r", encoding="utf-8") as file:
276                lines = file.readlines()
277                del lines[0]
278        # write header
279        with open(csv_path, "w", encoding="utf-8") as file:
280            writer = csv.DictWriter(file, fieldnames=fieldnames)
281            writer.writeheader()
282        # write metrics
283        with open(csv_path, "a", encoding="utf-8") as file:
284            if not is_first:
285                file.writelines(lines)  # write the previous lines
286            writer = csv.DictWriter(file, fieldnames=fieldnames)
287            writer.writerow(new_line)

Update the unlearning distribution distance metrics of unlearning tasks to CSV file.

Args:

  • csv_path (str): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/unlearning_test_after_task_X/distance.csv'.
def plot_distribution_distance_from_csv(self, csv_path: str, plot_path: str) -> None:
289    def plot_distribution_distance_from_csv(
290        self, csv_path: str, plot_path: str
291    ) -> None:
292        """Plot the unlearning test distance matrix over different unlearned tasks from saved CSV file and save the plot to the designated directory.
293
294        **Args:**
295        - **csv_path** (`str`): the path to the CSV file where the `update_distribution_distance_to_csv()` saved the unlearning test distance metric.
296        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/unlearning_test_after_task_X/distance.png'.
297        """
298        data = pd.read_csv(csv_path)
299
300        eval_task_ids = [
301            int(col.replace("unlearning_test_on_task_", ""))
302            for col in data.columns
303            if col.startswith("unlearning_test_on_task_")
304        ]
305        test_task_cols = [
306            col for col in data.columns if col.startswith("unlearning_test_on_task_")
307        ]
308        num_tasks = len(eval_task_ids)
309        num_rows = len(data)
310
311        # Build the distribution distance matrix
312        distance_matrix = data[test_task_cols].values
313
314        # plot the distribution distance matrix
315        fig, ax = plt.subplots(
316            figsize=(2 * num_tasks, 2 * num_rows)
317        )  # adaptive figure size
318        vmin = float(distance_matrix.min())
319        vmax = float(distance_matrix.max())
320        cax = ax.imshow(
321            distance_matrix,
322            interpolation="nearest",
323            cmap="Greens",
324            vmin=vmin,
325            vmax=vmax,
326            aspect="auto",
327        )
328
329        colorbar = fig.colorbar(cax)
330        yticks = colorbar.ax.get_yticks()
331        colorbar.ax.set_yticks(yticks)
332        colorbar.ax.set_yticklabels(
333            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
334        )  # adaptive font size
335
336        for r in range(num_rows):
337            for c in range(num_tasks):
338                ax.text(
339                    c,
340                    r,
341                    f"{distance_matrix[r, c]:.3f}",
342                    ha="center",
343                    va="center",
344                    color="black",
345                    fontsize=10 + num_tasks,  # adaptive font size
346                )
347
348        ax.set_xticks(range(num_tasks))
349        ax.set_yticks(range(num_rows))
350        ax.set_xticklabels(eval_task_ids, fontsize=10 + num_tasks)  # adaptive font size
351        ax.set_yticklabels(
352            range(1, num_rows + 1), fontsize=10 + num_rows
353        )  # adaptive font size
354
355        # Labeling the axes
356        ax.set_xlabel(
357            "Testing unlearning on task τ", fontsize=10 + num_tasks
358        )  # adaptive font size
359        ax.set_ylabel(
360            "Unlearning test after training task t", fontsize=10 + num_tasks
361        )  # adaptive font size
362        fig.tight_layout()
363        fig.savefig(plot_path)
364        plt.close(fig)

Plot the unlearning test distance matrix over different unlearned tasks from saved CSV file and save the plot to the designated directory.

Args:

  • csv_path (str): the path to the CSV file where the update_distribution_distance_to_csv() saved the unlearning test distance metric.
  • plot_path (str): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/unlearning_test_after_task_X/distance.png'.