clarena.metrics.cl_acc

The submodule in metrics for CLAccuracy.

  1r"""
  2The submodule in `metrics` for `CLAccuracy`.
  3"""
  4
  5__all__ = ["CLAccuracy"]
  6
  7import csv
  8import logging
  9import os
 10from typing import Any
 11
 12import pandas as pd
 13from lightning import Trainer
 14from lightning.pytorch.utilities import rank_zero_only
 15from matplotlib import pyplot as plt
 16from torchmetrics import MeanMetric
 17
 18from clarena.cl_algorithms import CLAlgorithm
 19from clarena.metrics import MetricCallback
 20from clarena.utils.metrics import MeanMetricBatch
 21
 22# always get logger for built-in logging in each module
 23pylogger = logging.getLogger(__name__)
 24
 25
 26class CLAccuracy(MetricCallback):
 27    r"""Provides all actions related to the CL accuracy metric, which include:
 28
 29    - Defining, initializing and recording the accuracy metric.
 30    - Logging the training and validation accuracy metrics to Lightning loggers in real time.
 31    - Saving the test accuracy metric to files.
 32    - Visualizing the test accuracy metric as plots.
 33
 34    The callback can produce the following outputs:
 35
 36    - CSV files for test accuracy (lower triangular) matrix and average accuracy. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 37    - Coloured plot for test accuracy (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 38    - Curve plots for test average accuracy over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.
 39
 40    Please refer to [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
 41    """
 42
 43    def __init__(
 44        self,
 45        save_dir: str,
 46        test_acc_csv_name: str = "acc.csv",
 47        test_acc_matrix_plot_name: str | None = None,
 48        test_ave_acc_plot_name: str | None = None,
 49    ) -> None:
 50        r"""
 51        **Args:**
 52        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Preferably inside the output folder.
 53        - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
 54        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
 55        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
 56        """
 57        super().__init__(save_dir=save_dir)
 58
 59        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 60        r"""The path to save test accuracy matrix and average accuracy CSV file."""
 61        if test_acc_matrix_plot_name:
 62            self.test_acc_matrix_plot_path: str = os.path.join(
 63                save_dir, test_acc_matrix_plot_name
 64            )
 65            r"""The path to save test accuracy matrix plot."""
 66        if test_ave_acc_plot_name:
 67            self.test_ave_acc_plot_path: str = os.path.join(
 68                save_dir, test_ave_acc_plot_name
 69            )
 70            r"""The path to save test average accuracy curve plot."""
 71
 72        # training accumulated metrics
 73        self.acc_training_epoch: MeanMetricBatch
 74        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 75
 76        # validation accumulated metrics
 77        self.acc_val: MeanMetricBatch
 78        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
 79
 80        # test accumulated metrics
 81        self.acc_test: dict[int, MeanMetricBatch]
 82        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
 83
 84        # task ID control
 85        self.task_id: int
 86        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 87
 88    @rank_zero_only
 89    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 90        r"""Initialize training and validation metrics."""
 91
 92        # set the current task_id from the `CLAlgorithm` object
 93        self.task_id = pl_module.task_id
 94
 95        # get the device to put the metrics on the same device
 96        device = pl_module.device
 97
 98        # initialize training metrics
 99        self.acc_training_epoch = MeanMetricBatch().to(device)
100
101        # initialize validation metrics
102        self.acc_val = MeanMetricBatch().to(device)
103
104    @rank_zero_only
105    def on_train_batch_end(
106        self,
107        trainer: Trainer,
108        pl_module: CLAlgorithm,
109        outputs: dict[str, Any],
110        batch: Any,
111        batch_idx: int,
112    ) -> None:
113        r"""Record training metrics from the training batch, and log both the batch metrics and the accumulated epoch metrics to Lightning loggers.
114
115        **Args:**
116        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
117        - **batch** (`Any`): the training data batch.
118        """
119        # get the batch size
120        batch_size = len(batch)
121
122        # get training metrics values of current training batch from the outputs of the `training_step()`
123        acc_batch = outputs["acc"]
124
125        # update accumulated training metrics to calculate training metrics of the epoch
126        self.acc_training_epoch.update(acc_batch, batch_size)
127
128        # log training metrics of current training batch to Lightning loggers
129        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)
130
131        # log accumulated training metrics till this training batch to Lightning loggers
132        pl_module.log(
133            f"task_{self.task_id}/train/acc",
134            self.acc_training_epoch.compute(),
135            prog_bar=True,
136        )
137
138    @rank_zero_only
139    def on_train_epoch_end(
140        self,
141        trainer: Trainer,
142        pl_module: CLAlgorithm,
143    ) -> None:
144        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
145
146        # log the accumulated and computed metrics of the epoch to Lightning loggers, especially for plotting learning curves
147        pl_module.log(
148            f"task_{self.task_id}/learning_curve/train/acc",
149            self.acc_training_epoch.compute(),
150            on_epoch=True,
151            prog_bar=True,
152        )
153
154        # reset the training-epoch metrics, since training runs for multiple epochs (unlike validation and test, which run only once)
155        self.acc_training_epoch.reset()
156
157    @rank_zero_only
158    def on_validation_batch_end(
159        self,
160        trainer: Trainer,
161        pl_module: CLAlgorithm,
162        outputs: dict[str, Any],
163        batch: Any,
164        batch_idx: int,
165    ) -> None:
166        r"""Accumulate metrics from the validation batch. We don't need to log or monitor the metrics of individual validation batches.
167
168        **Args:**
169        - **outputs** (`dict[str, Any]`): the outputs of the validation step, i.e. the returns of the `validation_step()` method in the `CLAlgorithm`.
170        - **batch** (`Any`): the validation data batch.
171        """
172
173        # get the batch size
174        batch_size = len(batch)
175
176        # get the metrics values of the batch from the outputs
177        acc_batch = outputs["acc"]
178
179        # update the accumulated metrics in order to calculate the validation metrics
180        self.acc_val.update(acc_batch, batch_size)
181
182    @rank_zero_only
183    def on_validation_epoch_end(
184        self,
185        trainer: Trainer,
186        pl_module: CLAlgorithm,
187    ) -> None:
188        r"""Log validation metrics to plot learning curves."""
189
190        # log the accumulated and computed metrics of the epoch to Lightning loggers, especially for plotting learning curves
191        pl_module.log(
192            f"task_{self.task_id}/learning_curve/val/acc",
193            self.acc_val.compute(),
194            on_epoch=True,
195            prog_bar=True,
196        )
197
198    @rank_zero_only
199    def on_test_start(
200        self,
201        trainer: Trainer,
202        pl_module: CLAlgorithm,
203    ) -> None:
204        r"""Initialize the metrics for testing each seen task at the beginning of a task's testing."""
205
206        # set the current task_id again (double checking) from the `CLAlgorithm` object
207        self.task_id = pl_module.task_id
208
209        # get the device to put the metrics on the same device
210        device = pl_module.device
211
212        # initialize test metrics for current and previous tasks
213        self.acc_test = {
214            task_id: MeanMetricBatch().to(device)
215            for task_id in pl_module.processed_task_ids
216        }
217
218    @rank_zero_only
219    def on_test_batch_end(
220        self,
221        trainer: Trainer,
222        pl_module: CLAlgorithm,
223        outputs: dict[str, Any],
224        batch: Any,
225        batch_idx: int,
226        dataloader_idx: int = 0,
227    ) -> None:
228        r"""Accumulate metrics from the test batch. We don't need to log or monitor the metrics of individual test batches.
229
230        **Args:**
231        - **outputs** (`dict[str, Any]`): the outputs of the test step, i.e. the returns of the `test_step()` method in the `CLAlgorithm`.
232        - **batch** (`Any`): the test data batch.
233        - **dataloader_idx** (`int`): the index of the test dataloader, which maps to the ID of the seen task being tested. A default value of 0 is given, otherwise the `LightningModule` will raise a `RuntimeError`.
234        """
235
236        # get the batch size
237        batch_size = len(batch)
238
239        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
240
241        # get the metrics values of the batch from the outputs
242        acc_batch = outputs["acc"]
243
244        # update the accumulated metrics in order to calculate the metrics of the epoch
245        self.acc_test[test_task_id].update(acc_batch, batch_size)
246
247    @rank_zero_only
248    def on_test_epoch_end(
249        self,
250        trainer: Trainer,
251        pl_module: CLAlgorithm,
252    ) -> None:
253        r"""Save and plot test metrics at the end of test."""
254
255        # save (update) the test metrics to CSV files
256        self.update_test_acc_to_csv(
257            after_training_task_id=self.task_id,
258            csv_path=self.test_acc_csv_path,
259        )
260
261        # plot the test metrics
262        if hasattr(self, "test_acc_matrix_plot_path"):
263            self.plot_test_acc_matrix_from_csv(
264                csv_path=self.test_acc_csv_path,
265                plot_path=self.test_acc_matrix_plot_path,
266            )
267        if hasattr(self, "test_ave_acc_plot_path"):
268            self.plot_test_ave_acc_curve_from_csv(
269                csv_path=self.test_acc_csv_path,
270                plot_path=self.test_ave_acc_plot_path,
271            )
272
273    def update_test_acc_to_csv(
274        self,
275        after_training_task_id: int,
276        csv_path: str,
277    ) -> None:
278        r"""Append the test accuracy metrics of seen tasks as the last line of an existing CSV file. A new file will be created if it does not exist.
279
280        **Args:**
281        - **after_training_task_id** (`int`): the ID of the task that has just been trained.
282        - **csv_path** (`str`): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
283        """
284        processed_task_ids = list(self.acc_test.keys())
285        fieldnames = ["after_training_task", "average_accuracy"] + [
286            f"test_on_task_{task_id}" for task_id in processed_task_ids
287        ]
288
289        new_line = {
290            "after_training_task": after_training_task_id
291        }  # construct the first column
292
293        # construct the columns and calculate the average accuracy over tasks at the same time
294        average_accuracy_over_tasks = MeanMetric().to(
295            device=next(iter(self.acc_test.values())).device
296        )
297        for task_id in processed_task_ids:
298            acc = self.acc_test[task_id].compute().item()
299            new_line[f"test_on_task_{task_id}"] = acc
300            average_accuracy_over_tasks(acc)
301        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()
302
303        # write to the csv file
304        is_first = not os.path.exists(csv_path)
305        if not is_first:
306            with open(csv_path, "r", encoding="utf-8") as file:
307                lines = file.readlines()
308                del lines[0]
309        # write header
310        with open(csv_path, "w", encoding="utf-8") as file:
311            writer = csv.DictWriter(file, fieldnames=fieldnames)
312            writer.writeheader()
313        # write metrics
314        with open(csv_path, "a", encoding="utf-8") as file:
315            if not is_first:
316                file.writelines(lines)  # write the previous lines
317            writer = csv.DictWriter(file, fieldnames=fieldnames)
318            writer.writerow(new_line)
319
320    def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
321        """Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory.
322
323        **Args:**
324        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
325        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
326        """
327        data = pd.read_csv(csv_path)
328        processed_task_ids = [
329            int(col.replace("test_on_task_", ""))
330            for col in data.columns
331            if col.startswith("test_on_task_")
332        ]
333
334        # Get all columns that start with "test_on_task_"
335        test_task_cols = [
336            col for col in data.columns if col.startswith("test_on_task_")
337        ]
338        num_tasks = len(processed_task_ids)
339        num_rows = len(data)
340
341        # Build the accuracy matrix
342        acc_matrix = data[test_task_cols].values
343
344        fig, ax = plt.subplots(
345            figsize=(2 * num_tasks, 2 * num_rows)
346        )  # adaptive figure size
347
348        cax = ax.imshow(
349            acc_matrix,
350            interpolation="nearest",
351            cmap="Greens",
352            vmin=0,
353            vmax=1,
354            aspect="auto",
355        )
356
357        colorbar = fig.colorbar(cax)
358        yticks = colorbar.ax.get_yticks()
359        colorbar.ax.set_yticks(yticks)
360        colorbar.ax.set_yticklabels(
361            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
362        )
363
364        # Annotate each cell
365        for r in range(num_rows):
366            for c in range(r + 1):
367                ax.text(
368                    c,
369                    r,
370                    f"{acc_matrix[r, c]:.3f}",
371                    ha="center",
372                    va="center",
373                    color="black",
374                    fontsize=10 + num_tasks,
375                )
376
377        ax.set_xticks(range(num_tasks))
378        ax.set_yticks(range(num_rows))
379        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
380        ax.set_yticklabels(
381            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
382        )
383
384        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
385        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
386        fig.tight_layout()
387        fig.savefig(plot_path)
388        plt.close(fig)
389
390    def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
391        """Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory.
392
393        **Args:**
394        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
395        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
396        """
397        data = pd.read_csv(csv_path)
398        after_training_tasks = data["after_training_task"].astype(int).tolist()
399
400        # plot the average accuracy curve over different training tasks
401        fig, ax = plt.subplots(figsize=(16, 9))
402        ax.plot(
403            after_training_tasks,
404            data["average_accuracy"],
405            marker="o",
406            linewidth=2,
407        )
408        ax.set_xlabel("After training task $t$", fontsize=16)
409        ax.set_ylabel("Average Accuracy (AA)", fontsize=16)
410        ax.grid(True)
411        xticks = after_training_tasks
412        yticks = [i * 0.05 for i in range(21)]
413        ax.set_xticks(xticks)
414        ax.set_yticks(yticks)
415        ax.set_xticklabels(xticks, fontsize=16)
416        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
417        fig.savefig(plot_path)
418        plt.close(fig)
class CLAccuracy(clarena.metrics.base.MetricCallback):

Provides all actions related to the CL accuracy metric, which include:

  • Defining, initializing and recording the accuracy metric.
  • Logging the training and validation accuracy metrics to Lightning loggers in real time.
  • Saving the test accuracy metric to files.
  • Visualizing the test accuracy metric as plots.

The callback can produce the following outputs:

  • CSV files for test accuracy (lower triangular) matrix and average accuracy. See here for details.
  • Coloured plot for test accuracy (lower triangular) matrix. See here for details.
  • Curve plots for test average accuracy over different training tasks. See here for details.

Please refer to A Summary of Continual Learning Metrics to learn about this metric.
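A minimal usage sketch (hedged: in CLArena experiments the callback is normally instantiated from configuration files; the import path from clarena.metrics and the directory below are assumptions for illustration):

from lightning import Trainer

from clarena.metrics import CLAccuracy

# collect accuracy results under the experiment's output folder (hypothetical path)
acc_metric = CLAccuracy(
    save_dir="./outputs/expr_name/1970-01-01_00-00-00/results",
    test_acc_csv_name="acc.csv",
    test_acc_matrix_plot_name="acc_matrix.png",  # set to None to skip this plot
    test_ave_acc_plot_name="ave_acc.png",  # set to None to skip this plot
)

# attach it to the Lightning Trainer alongside the experiment's other callbacks
trainer = Trainer(callbacks=[acc_metric])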

CLAccuracy( save_dir: str, test_acc_csv_name: str = 'acc.csv', test_acc_matrix_plot_name: str | None = None, test_ave_acc_plot_name: str | None = None)

Args:

  • save_dir (str): The directory where data and figures of metrics will be saved. Preferably inside the output folder.
  • test_acc_csv_name (str): file name to save test accuracy matrix and average accuracy as CSV file.
  • test_acc_matrix_plot_name (str | None): file name to save accuracy matrix plot. If None, no file will be saved.
  • test_ave_acc_plot_name (str | None): file name to save average accuracy as curve plot over different training tasks. If None, no file will be saved.
test_acc_csv_path: str

The path to save test accuracy matrix and average accuracy CSV file.

acc_training_epoch: MeanMetricBatch

Classification accuracy of training epoch. Accumulated and calculated from the training batches. See here for details.

acc_val: MeanMetricBatch

Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See here for details.

acc_test: dict[int, MeanMetricBatch]

Test classification accuracy of the current model (self.task_id) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.

task_id: int

Task ID counter indicating which task is being processed. Self-updated during the task loop. Valid from 1 to cl_dataset.num_tasks.
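The accumulator attributes above are MeanMetricBatch instances from clarena.utils.metrics. As a rough behavioural sketch only (a hypothetical stand-in, not the library's implementation), they act as sample-weighted running means driven by the update(), compute() and reset() calls used in the hooks below:

class MeanOverBatches:
    """Hypothetical stand-in for MeanMetricBatch: a sample-weighted running mean."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, batch_size: int) -> None:
        # weight each batch-level value by the number of samples in the batch
        self.total += float(value) * batch_size
        self.count += batch_size

    def compute(self) -> float:
        return self.total / self.count if self.count else 0.0

    def reset(self) -> None:
        self.total, self.count = 0.0, 0


m = MeanOverBatches()
m.update(0.75, batch_size=32)  # accuracy of a batch of 32 samples
m.update(0.50, batch_size=16)  # accuracy of a batch of 16 samples
print(m.compute())  # 0.666..., the sample-weighted mean over 48 samples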

@rank_zero_only
def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:

Initialize training and validation metrics.

@rank_zero_only
def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:

Record training metrics from the training batch, and log both the batch metrics and the accumulated epoch metrics to Lightning loggers.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, the returns of the training_step() method in the CLAlgorithm.
  • batch (Any): the training data batch.
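For reference, a hedged sketch of the contract this hook assumes: training_step() in the CLAlgorithm returns a dictionary containing the batch accuracy under the key "acc", which the callback reads as outputs["acc"]. The criterion attribute and tensor handling below are illustrative, not the actual CLAlgorithm implementation:

def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.criterion(logits, y)  # illustrative criterion attribute
    acc = (logits.argmax(dim=1) == y).float().mean()  # batch classification accuracy
    # the "acc" entry is what CLAccuracy.on_train_batch_end() picks up
    return {"loss": loss, "acc": acc}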
@rank_zero_only
def on_train_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:

Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.

@rank_zero_only
def on_validation_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:

Accumulate metrics from the validation batch. We don't need to log or monitor the metrics of individual validation batches.

Args:

  • outputs (dict[str, Any]): the outputs of the validation step, i.e. the returns of the validation_step() method in the CLAlgorithm.
  • batch (Any): the validation data batch.
@rank_zero_only
def on_validation_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:

Log validation metrics to plot learning curves.

@rank_zero_only
def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:

Initialize the metrics for testing each seen task at the beginning of a task's testing.

@rank_zero_only
def on_test_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:

Accumulate metrics from the test batch. We don't need to log or monitor the metrics of individual test batches.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, i.e. the returns of the test_step() method in the CLAlgorithm.
  • batch (Any): the test data batch.
  • dataloader_idx (int): the index of the test dataloader, which maps to the ID of the seen task being tested. A default value of 0 is given, otherwise the LightningModule will raise a RuntimeError. A sketch of this mapping follows below.
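A rough sketch of that mapping (task IDs are hypothetical; the real lookup is CLAlgorithm.get_test_task_id_from_dataloader_idx()): Lightning passes one dataloader_idx per test dataloader, and the callback resolves it to the task whose accumulator should be updated.

processed_task_ids = [1, 2, 3]  # hypothetical: tasks seen so far, one test dataloader each

def get_test_task_id_from_dataloader_idx(dataloader_idx: int) -> int:
    # a plausible implementation: test dataloaders are ordered by task
    return processed_task_ids[dataloader_idx]

# on_test_batch_end() then routes each batch's accuracy to the matching accumulator:
# self.acc_test[test_task_id].update(outputs["acc"], batch_size)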
@rank_zero_only
def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:

Save and plot test metrics at the end of test.

def update_test_acc_to_csv(self, after_training_task_id: int, csv_path: str) -> None:

Append the test accuracy metrics of seen tasks as the last line of an existing CSV file. A new file will be created if it does not exist.

Args:

  • after_training_task_id (int): the ID of the task that has just been trained.
  • csv_path (str): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'. See the illustration below for the resulting file layout.
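As an illustration (accuracy values hypothetical): each call appends one row and rewrites the header for all seen tasks, so the test_on_task_* columns form a lower triangular matrix with empty cells above the diagonal. The sketch below also shows how the matrix can be read back with pandas:

# acc.csv after training task 2 (values are hypothetical):
#
#   after_training_task,average_accuracy,test_on_task_1,test_on_task_2
#   1,0.93,0.93,
#   2,0.90,0.88,0.92

import pandas as pd

data = pd.read_csv("./outputs/expr_name/1970-01-01_00-00-00/results/acc.csv")
acc_matrix = data[[c for c in data.columns if c.startswith("test_on_task_")]].values
print(acc_matrix)  # lower triangular; NaN above the diagonal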
def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:

Plot the test accuracy matrix from the saved CSV file and save the plot to the designated directory.

Args:

  • csv_path (str): the path to the CSV file where update_test_acc_to_csv() saved the test accuracy metric.
  • plot_path (str): the path to save the plot. Preferably the same as the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'. A usage sketch follows below.
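A possible post-hoc usage (paths hypothetical): the plotting helpers only read the saved CSV, so they can be re-run on an existing results file.

results_dir = "./outputs/expr_name/1970-01-01_00-00-00/results"

metric = CLAccuracy(save_dir=results_dir)
metric.plot_test_acc_matrix_from_csv(
    csv_path=f"{results_dir}/acc.csv",
    plot_path=f"{results_dir}/acc_matrix.png",
)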
def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:

Plot the test average accuracy curve over different training tasks from the saved CSV file and save the plot to the designated directory.

Args:

  • csv_path (str): the path to the CSV file where update_test_acc_to_csv() saved the test accuracy metric.
  • plot_path (str): the path to save the plot. Preferably the same as the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
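If the callback was run without test_ave_acc_plot_name, an equivalent average-accuracy curve can still be produced later from the CSV alone; a minimal sketch with pandas and matplotlib (file paths hypothetical):

import pandas as pd
from matplotlib import pyplot as plt

data = pd.read_csv("./outputs/expr_name/1970-01-01_00-00-00/results/acc.csv")

fig, ax = plt.subplots()
ax.plot(data["after_training_task"], data["average_accuracy"], marker="o")
ax.set_xlabel("After training task $t$")
ax.set_ylabel("Average Accuracy (AA)")
ax.grid(True)
fig.savefig("./outputs/expr_name/1970-01-01_00-00-00/results/ave_acc.png")
plt.close(fig)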