clarena.metrics.cl_acc

The submodule in `metrics` for `CLAccuracy`.

  1r"""
  2The submodule in `metrics` for `CLAccuracy`.
  3"""
  4
  5__all__ = ["CLAccuracy"]
  6
  7import csv
  8import logging
  9import os
 10from typing import Any
 11
 12import pandas as pd
 13from lightning import Trainer
 14from lightning.pytorch.utilities import rank_zero_only
 15from matplotlib import pyplot as plt
 16from torchmetrics import MeanMetric
 17
 18from clarena.cl_algorithms import CLAlgorithm
 19from clarena.metrics import MetricCallback
 20from clarena.utils.metrics import MeanMetricBatch
 21
 22# always get logger for built-in logging in each module
 23pylogger = logging.getLogger(__name__)
 24
 25
 26class CLAccuracy(MetricCallback):
 27    r"""Provides all actions that are related to CL accuracy metric, which include:
 28
 29    - Defining, initializing and recording accuracy metric.
 30    - Logging training and validation accuracy metric to Lightning loggers in real time.
 31    - Saving test accuracy metric to files.
 32    - Visualizing test accuracy metric as plots.
 33
 34    The callback is able to produce the following outputs:
 35
 36    - CSV files for test accuracy (lower triangular) matrix and average accuracy. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 37    - Coloured plot for test accuracy (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 38    - Curve plots for test average accuracy over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.
 39
 40    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
 41    """
 42
 43    def __init__(
 44        self,
 45        save_dir: str,
 46        test_acc_csv_name: str = "acc.csv",
 47        test_acc_matrix_plot_name: str | None = None,
 48        test_ave_acc_plot_name: str | None = None,
 49    ) -> None:
 50        r"""
 51        **Args:**
 52        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 53        - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
 54        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
 55        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
 56        """
 57        super().__init__(save_dir=save_dir)
 58
 59        # paths
 60        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 61        r"""The path to save test accuracy matrix and average accuracy CSV file."""
 62        if test_acc_matrix_plot_name:
 63            self.test_acc_matrix_plot_path: str = os.path.join(
 64                save_dir, test_acc_matrix_plot_name
 65            )
 66            r"""The path to save test accuracy matrix plot."""
 67        if test_ave_acc_plot_name:
 68            self.test_ave_acc_plot_path: str = os.path.join(
 69                save_dir, test_ave_acc_plot_name
 70            )
 71            r"""The path to save test average accuracy curve plot."""
 72
 73        # training accumulated metrics
 74        self.acc_training_epoch: MeanMetricBatch
 75        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 76
 77        # validation accumulated metrics
 78        self.acc_val: MeanMetricBatch
 79        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
 80
 81        # test accumulated metrics
 82        self.acc_test: dict[int, MeanMetricBatch]
 83        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
 84
 85        # task ID control
 86        self.task_id: int
 87        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 88
 89    @rank_zero_only
 90    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 91        r"""Initialize training and validation metrics."""
 92
 93        # set the current task_id from the `CLAlgorithm` object
 94        self.task_id = pl_module.task_id
 95
 96        # get the device to put the metrics on the same device
 97        device = pl_module.device
 98
 99        # initialize training metrics
100        self.acc_training_epoch = MeanMetricBatch().to(device)
101
102        # initialize validation metrics
103        self.acc_val = MeanMetricBatch().to(device)
104
105    @rank_zero_only
106    def on_train_batch_end(
107        self,
108        trainer: Trainer,
109        pl_module: CLAlgorithm,
110        outputs: dict[str, Any],
111        batch: Any,
112        batch_idx: int,
113    ) -> None:
114        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
115
116        **Args:**
117        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
118        - **batch** (`Any`): the training data batch.
119        """
120        # get the batch size
121        batch_size = len(batch)
122
123        # get training metrics values of current training batch from the outputs of the `training_step()`
124        acc_batch = outputs["acc"]
125
126        # update accumulated training metrics to calculate training metrics of the epoch
127        self.acc_training_epoch.update(acc_batch, batch_size)
128
129        # log training metrics of current training batch to Lightning loggers
130        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)
131
132        # log accumulated training metrics till this training batch to Lightning loggers
133        pl_module.log(
134            f"task_{self.task_id}/train/acc",
135            self.acc_training_epoch.compute(),
136            prog_bar=True,
137        )
138
139    @rank_zero_only
140    def on_train_epoch_end(
141        self,
142        trainer: Trainer,
143        pl_module: CLAlgorithm,
144    ) -> None:
145        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
146
147        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
148        pl_module.log(
149            f"task_{self.task_id}/learning_curve/train/acc",
150            self.acc_training_epoch.compute(),
151            on_epoch=True,
152            prog_bar=True,
153        )
154
155        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
156        self.acc_training_epoch.reset()
157
158    @rank_zero_only
159    def on_validation_batch_end(
160        self,
161        trainer: Trainer,
162        pl_module: CLAlgorithm,
163        outputs: dict[str, Any],
164        batch: Any,
165        batch_idx: int,
166    ) -> None:
167        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
168
169        **Args:**
170        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
171        - **batch** (`Any`): the validation data batch.
172        """
173
174        # get the batch size
175        batch_size = len(batch)
176
177        # get the metrics values of the batch from the outputs
178        acc_batch = outputs["acc"]
179
180        # update the accumulated metrics in order to calculate the validation metrics
181        self.acc_val.update(acc_batch, batch_size)
182
183    @rank_zero_only
184    def on_validation_epoch_end(
185        self,
186        trainer: Trainer,
187        pl_module: CLAlgorithm,
188    ) -> None:
189        r"""Log validation metrics to plot learning curves."""
190
191        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
192        pl_module.log(
193            f"task_{self.task_id}/learning_curve/val/acc",
194            self.acc_val.compute(),
195            on_epoch=True,
196            prog_bar=True,
197        )
198
199    @rank_zero_only
200    def on_test_start(
201        self,
202        trainer: Trainer,
203        pl_module: CLAlgorithm,
204    ) -> None:
205        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
206
207        # set the current task_id again (double checking) from the `CLAlgorithm` object
208        self.task_id = pl_module.task_id
209
210        # get the device to put the metrics on the same device
211        device = pl_module.device
212
213        # initialize test metrics for current and previous tasks
214        self.acc_test = {
215            task_id: MeanMetricBatch().to(device)
216            for task_id in pl_module.processed_task_ids
217        }
218
219    @rank_zero_only
220    def on_test_batch_end(
221        self,
222        trainer: Trainer,
223        pl_module: CLAlgorithm,
224        outputs: dict[str, Any],
225        batch: Any,
226        batch_idx: int,
227        dataloader_idx: int = 0,
228    ) -> None:
229        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
230
231        **Args:**
232        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
233        - **batch** (`Any`): the test data batch.
234        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
235        """
236
237        # get the batch size
238        batch_size = len(batch)
239
240        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
241
242        # get the metrics values of the batch from the outputs
243        acc_batch = outputs["acc"]
244
245        # update the accumulated metrics in order to calculate the metrics of the epoch
246        self.acc_test[test_task_id].update(acc_batch, batch_size)
247
248    @rank_zero_only
249    def on_test_epoch_end(
250        self,
251        trainer: Trainer,
252        pl_module: CLAlgorithm,
253    ) -> None:
254        r"""Save and plot test metrics at the end of test."""
255
256        # save (update) the test metrics to CSV files
257        self.update_test_acc_to_csv(
258            after_training_task_id=self.task_id,
259            csv_path=self.test_acc_csv_path,
260        )
261
262        # plot the test metrics
263        if hasattr(self, "test_acc_matrix_plot_path"):
264            self.plot_test_acc_matrix_from_csv(
265                csv_path=self.test_acc_csv_path,
266                plot_path=self.test_acc_matrix_plot_path,
267            )
268        if hasattr(self, "test_ave_acc_plot_path"):
269            self.plot_test_ave_acc_curve_from_csv(
270                csv_path=self.test_acc_csv_path,
271                plot_path=self.test_ave_acc_plot_path,
272            )
273
274    def update_test_acc_to_csv(
275        self,
276        after_training_task_id: int,
277        csv_path: str,
278    ) -> None:
279        r"""Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.
280
281        **Args:**
282        - **after_training_task_id** (`int`): the task ID after training.
283        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
284        """
285        processed_task_ids = list(self.acc_test.keys())
286        fieldnames = ["after_training_task", "average_accuracy"] + [
287            f"test_on_task_{task_id}" for task_id in processed_task_ids
288        ]
289
290        new_line = {
291            "after_training_task": after_training_task_id
292        }  # construct the first column
293
294        # construct the columns and calculate the average accuracy over tasks at the same time
295        average_accuracy_over_tasks = MeanMetric().to(
296            device=next(iter(self.acc_test.values())).device
297        )
298        for task_id in processed_task_ids:
299            acc = self.acc_test[task_id].compute().item()
300            new_line[f"test_on_task_{task_id}"] = acc
301            average_accuracy_over_tasks(acc)
302        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()
303
304        # write to the csv file
305        is_first = not os.path.exists(csv_path)
306        if not is_first:
307            with open(csv_path, "r", encoding="utf-8") as file:
308                lines = file.readlines()
309                del lines[0]
310        # write header
311        with open(csv_path, "w", encoding="utf-8") as file:
312            writer = csv.DictWriter(file, fieldnames=fieldnames)
313            writer.writeheader()
314        # write metrics
315        with open(csv_path, "a", encoding="utf-8") as file:
316            if not is_first:
317                file.writelines(lines)  # write the previous lines
318            writer = csv.DictWriter(file, fieldnames=fieldnames)
319            writer.writerow(new_line)
320
321    def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
322        """Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory.
323
324        **Args:**
325        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
326        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
327        """
328        data = pd.read_csv(csv_path)
329        processed_task_ids = [
330            int(col.replace("test_on_task_", ""))
331            for col in data.columns
332            if col.startswith("test_on_task_")
333        ]
334
335        # Get all columns that start with "test_on_task_"
336        test_task_cols = [
337            col for col in data.columns if col.startswith("test_on_task_")
338        ]
339        num_tasks = len(processed_task_ids)
340        num_rows = len(data)
341
342        # Build the accuracy matrix
343        acc_matrix = data[test_task_cols].values
344
345        fig, ax = plt.subplots(
346            figsize=(2 * num_tasks, 2 * num_rows)
347        )  # adaptive figure size
348
349        cax = ax.imshow(
350            acc_matrix,
351            interpolation="nearest",
352            cmap="Greens",
353            vmin=0,
354            vmax=1,
355            aspect="auto",
356        )
357
358        colorbar = fig.colorbar(cax)
359        yticks = colorbar.ax.get_yticks()
360        colorbar.ax.set_yticks(yticks)
361        colorbar.ax.set_yticklabels(
362            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
363        )
364
365        # Annotate each cell
366        for r in range(num_rows):
367            for c in range(r + 1):
368                ax.text(
369                    c,
370                    r,
371                    f"{acc_matrix[r, c]:.3f}",
372                    ha="center",
373                    va="center",
374                    color="black",
375                    fontsize=10 + num_tasks,
376                )
377
378        ax.set_xticks(range(num_tasks))
379        ax.set_yticks(range(num_rows))
380        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
381        ax.set_yticklabels(
382            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
383        )
384
385        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
386        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
387        fig.tight_layout()
388        fig.savefig(plot_path)
389        plt.close(fig)
390
391    def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
392        """Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory.
393
394        **Args:**
395        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
396        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
397        """
398        data = pd.read_csv(csv_path)
399        after_training_tasks = data["after_training_task"].astype(int).tolist()
400
401        # plot the average accuracy curve over different training tasks
402        fig, ax = plt.subplots(figsize=(16, 9))
403        ax.plot(
404            after_training_tasks,
405            data["average_accuracy"],
406            marker="o",
407            linewidth=2,
408        )
409        ax.set_xlabel("After training task $t$", fontsize=16)
410        ax.set_ylabel("Average Accuracy (AA)", fontsize=16)
411        ax.grid(True)
412        xticks = after_training_tasks
413        yticks = [i * 0.05 for i in range(21)]
414        ax.set_xticks(xticks)
415        ax.set_yticks(yticks)
416        ax.set_xticklabels(xticks, fontsize=16)
417        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
418        fig.savefig(plot_path)
419        plt.close(fig)
class CLAccuracy(clarena.metrics.base.MetricCallback):
 27class CLAccuracy(MetricCallback):
 28    r"""Provides all actions that are related to CL accuracy metric, which include:
 29
 30    - Defining, initializing and recording accuracy metric.
 31    - Logging training and validation accuracy metric to Lightning loggers in real time.
 32    - Saving test accuracy metric to files.
 33    - Visualizing test accuracy metric as plots.
 34
 35    The callback is able to produce the following outputs:
 36
 37    - CSV files for test accuracy (lower triangular) matrix and average accuracy. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 38    - Coloured plot for test accuracy (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 39    - Curve plots for test average accuracy over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.
 40
 41    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
 42    """
 43
 44    def __init__(
 45        self,
 46        save_dir: str,
 47        test_acc_csv_name: str = "acc.csv",
 48        test_acc_matrix_plot_name: str | None = None,
 49        test_ave_acc_plot_name: str | None = None,
 50    ) -> None:
 51        r"""
 52        **Args:**
 53        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 54        - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
 55        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
 56        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
 57        """
 58        super().__init__(save_dir=save_dir)
 59
 60        # paths
 61        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 62        r"""The path to save test accuracy matrix and average accuracy CSV file."""
 63        if test_acc_matrix_plot_name:
 64            self.test_acc_matrix_plot_path: str = os.path.join(
 65                save_dir, test_acc_matrix_plot_name
 66            )
 67            r"""The path to save test accuracy matrix plot."""
 68        if test_ave_acc_plot_name:
 69            self.test_ave_acc_plot_path: str = os.path.join(
 70                save_dir, test_ave_acc_plot_name
 71            )
 72            r"""The path to save test average accuracy curve plot."""
 73
 74        # training accumulated metrics
 75        self.acc_training_epoch: MeanMetricBatch
 76        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 77
 78        # validation accumulated metrics
 79        self.acc_val: MeanMetricBatch
 80        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
 81
 82        # test accumulated metrics
 83        self.acc_test: dict[int, MeanMetricBatch]
 84        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
 85
 86        # task ID control
 87        self.task_id: int
 88        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 89
 90    @rank_zero_only
 91    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 92        r"""Initialize training and validation metrics."""
 93
 94        # set the current task_id from the `CLAlgorithm` object
 95        self.task_id = pl_module.task_id
 96
 97        # get the device to put the metrics on the same device
 98        device = pl_module.device
 99
100        # initialize training metrics
101        self.acc_training_epoch = MeanMetricBatch().to(device)
102
103        # initialize validation metrics
104        self.acc_val = MeanMetricBatch().to(device)
105
106    @rank_zero_only
107    def on_train_batch_end(
108        self,
109        trainer: Trainer,
110        pl_module: CLAlgorithm,
111        outputs: dict[str, Any],
112        batch: Any,
113        batch_idx: int,
114    ) -> None:
115        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
116
117        **Args:**
118        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
119        - **batch** (`Any`): the training data batch.
120        """
121        # get the batch size
122        batch_size = len(batch)
123
124        # get training metrics values of current training batch from the outputs of the `training_step()`
125        acc_batch = outputs["acc"]
126
127        # update accumulated training metrics to calculate training metrics of the epoch
128        self.acc_training_epoch.update(acc_batch, batch_size)
129
130        # log training metrics of current training batch to Lightning loggers
131        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)
132
133        # log accumulated training metrics till this training batch to Lightning loggers
134        pl_module.log(
135            f"task_{self.task_id}/train/acc",
136            self.acc_training_epoch.compute(),
137            prog_bar=True,
138        )
139
140    @rank_zero_only
141    def on_train_epoch_end(
142        self,
143        trainer: Trainer,
144        pl_module: CLAlgorithm,
145    ) -> None:
146        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
147
148        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
149        pl_module.log(
150            f"task_{self.task_id}/learning_curve/train/acc",
151            self.acc_training_epoch.compute(),
152            on_epoch=True,
153            prog_bar=True,
154        )
155
156        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
157        self.acc_training_epoch.reset()
158
159    @rank_zero_only
160    def on_validation_batch_end(
161        self,
162        trainer: Trainer,
163        pl_module: CLAlgorithm,
164        outputs: dict[str, Any],
165        batch: Any,
166        batch_idx: int,
167    ) -> None:
168        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
169
170        **Args:**
171        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
172        - **batch** (`Any`): the validation data batch.
173        """
174
175        # get the batch size
176        batch_size = len(batch)
177
178        # get the metrics values of the batch from the outputs
179        acc_batch = outputs["acc"]
180
181        # update the accumulated metrics in order to calculate the validation metrics
182        self.acc_val.update(acc_batch, batch_size)
183
184    @rank_zero_only
185    def on_validation_epoch_end(
186        self,
187        trainer: Trainer,
188        pl_module: CLAlgorithm,
189    ) -> None:
190        r"""Log validation metrics to plot learning curves."""
191
192        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
193        pl_module.log(
194            f"task_{self.task_id}/learning_curve/val/acc",
195            self.acc_val.compute(),
196            on_epoch=True,
197            prog_bar=True,
198        )
199
200    @rank_zero_only
201    def on_test_start(
202        self,
203        trainer: Trainer,
204        pl_module: CLAlgorithm,
205    ) -> None:
206        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
207
208        # set the current task_id again (double checking) from the `CLAlgorithm` object
209        self.task_id = pl_module.task_id
210
211        # get the device to put the metrics on the same device
212        device = pl_module.device
213
214        # initialize test metrics for current and previous tasks
215        self.acc_test = {
216            task_id: MeanMetricBatch().to(device)
217            for task_id in pl_module.processed_task_ids
218        }
219
220    @rank_zero_only
221    def on_test_batch_end(
222        self,
223        trainer: Trainer,
224        pl_module: CLAlgorithm,
225        outputs: dict[str, Any],
226        batch: Any,
227        batch_idx: int,
228        dataloader_idx: int = 0,
229    ) -> None:
230        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
231
232        **Args:**
233        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
234        - **batch** (`Any`): the test data batch.
235        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
236        """
237
238        # get the batch size
239        batch_size = len(batch)
240
241        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
242
243        # get the metrics values of the batch from the outputs
244        acc_batch = outputs["acc"]
245
246        # update the accumulated metrics in order to calculate the metrics of the epoch
247        self.acc_test[test_task_id].update(acc_batch, batch_size)
248
249    @rank_zero_only
250    def on_test_epoch_end(
251        self,
252        trainer: Trainer,
253        pl_module: CLAlgorithm,
254    ) -> None:
255        r"""Save and plot test metrics at the end of test."""
256
257        # save (update) the test metrics to CSV files
258        self.update_test_acc_to_csv(
259            after_training_task_id=self.task_id,
260            csv_path=self.test_acc_csv_path,
261        )
262
263        # plot the test metrics
264        if hasattr(self, "test_acc_matrix_plot_path"):
265            self.plot_test_acc_matrix_from_csv(
266                csv_path=self.test_acc_csv_path,
267                plot_path=self.test_acc_matrix_plot_path,
268            )
269        if hasattr(self, "test_ave_acc_plot_path"):
270            self.plot_test_ave_acc_curve_from_csv(
271                csv_path=self.test_acc_csv_path,
272                plot_path=self.test_ave_acc_plot_path,
273            )
274
275    def update_test_acc_to_csv(
276        self,
277        after_training_task_id: int,
278        csv_path: str,
279    ) -> None:
280        r"""Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.
281
282        **Args:**
283        - **after_training_task_id** (`int`): the task ID after training.
284        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
285        """
286        processed_task_ids = list(self.acc_test.keys())
287        fieldnames = ["after_training_task", "average_accuracy"] + [
288            f"test_on_task_{task_id}" for task_id in processed_task_ids
289        ]
290
291        new_line = {
292            "after_training_task": after_training_task_id
293        }  # construct the first column
294
295        # construct the columns and calculate the average accuracy over tasks at the same time
296        average_accuracy_over_tasks = MeanMetric().to(
297            device=next(iter(self.acc_test.values())).device
298        )
299        for task_id in processed_task_ids:
300            acc = self.acc_test[task_id].compute().item()
301            new_line[f"test_on_task_{task_id}"] = acc
302            average_accuracy_over_tasks(acc)
303        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()
304
305        # write to the csv file
306        is_first = not os.path.exists(csv_path)
307        if not is_first:
308            with open(csv_path, "r", encoding="utf-8") as file:
309                lines = file.readlines()
310                del lines[0]
311        # write header
312        with open(csv_path, "w", encoding="utf-8") as file:
313            writer = csv.DictWriter(file, fieldnames=fieldnames)
314            writer.writeheader()
315        # write metrics
316        with open(csv_path, "a", encoding="utf-8") as file:
317            if not is_first:
318                file.writelines(lines)  # write the previous lines
319            writer = csv.DictWriter(file, fieldnames=fieldnames)
320            writer.writerow(new_line)
321
322    def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
323        """Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory.
324
325        **Args:**
326        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
327        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
328        """
329        data = pd.read_csv(csv_path)
330        processed_task_ids = [
331            int(col.replace("test_on_task_", ""))
332            for col in data.columns
333            if col.startswith("test_on_task_")
334        ]
335
336        # Get all columns that start with "test_on_task_"
337        test_task_cols = [
338            col for col in data.columns if col.startswith("test_on_task_")
339        ]
340        num_tasks = len(processed_task_ids)
341        num_rows = len(data)
342
343        # Build the accuracy matrix
344        acc_matrix = data[test_task_cols].values
345
346        fig, ax = plt.subplots(
347            figsize=(2 * num_tasks, 2 * num_rows)
348        )  # adaptive figure size
349
350        cax = ax.imshow(
351            acc_matrix,
352            interpolation="nearest",
353            cmap="Greens",
354            vmin=0,
355            vmax=1,
356            aspect="auto",
357        )
358
359        colorbar = fig.colorbar(cax)
360        yticks = colorbar.ax.get_yticks()
361        colorbar.ax.set_yticks(yticks)
362        colorbar.ax.set_yticklabels(
363            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
364        )
365
366        # Annotate each cell
367        for r in range(num_rows):
368            for c in range(r + 1):
369                ax.text(
370                    c,
371                    r,
372                    f"{acc_matrix[r, c]:.3f}",
373                    ha="center",
374                    va="center",
375                    color="black",
376                    fontsize=10 + num_tasks,
377                )
378
379        ax.set_xticks(range(num_tasks))
380        ax.set_yticks(range(num_rows))
381        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
382        ax.set_yticklabels(
383            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
384        )
385
386        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
387        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
388        fig.tight_layout()
389        fig.savefig(plot_path)
390        plt.close(fig)
391
392    def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
393        """Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory.
394
395        **Args:**
396        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
397        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
398        """
399        data = pd.read_csv(csv_path)
400        after_training_tasks = data["after_training_task"].astype(int).tolist()
401
402        # plot the average accuracy curve over different training tasks
403        fig, ax = plt.subplots(figsize=(16, 9))
404        ax.plot(
405            after_training_tasks,
406            data["average_accuracy"],
407            marker="o",
408            linewidth=2,
409        )
410        ax.set_xlabel("After training task $t$", fontsize=16)
411        ax.set_ylabel("Average Accuracy (AA)", fontsize=16)
412        ax.grid(True)
413        xticks = after_training_tasks
414        yticks = [i * 0.05 for i in range(21)]
415        ax.set_xticks(xticks)
416        ax.set_yticks(yticks)
417        ax.set_xticklabels(xticks, fontsize=16)
418        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
419        fig.savefig(plot_path)
420        plt.close(fig)

Provides all actions that are related to CL accuracy metric, which include:

  • Defining, initializing and recording accuracy metric.
  • Logging training and validation accuracy metric to Lightning loggers in real time.
  • Saving test accuracy metric to files.
  • Visualizing test accuracy metric as plots.

The callback is able to produce the following outputs:

  • CSV files for test accuracy (lower triangular) matrix and average accuracy. See here for details.
  • Coloured plot for test accuracy (lower triangular) matrix. See here for details.
  • Curve plots for test average accuracy over different training tasks. See here for details.

Please refer to A Summary of Continual Learning Metrics to learn about this metric.

CLAccuracy( save_dir: str, test_acc_csv_name: str = 'acc.csv', test_acc_matrix_plot_name: str | None = None, test_ave_acc_plot_name: str | None = None)
44    def __init__(
45        self,
46        save_dir: str,
47        test_acc_csv_name: str = "acc.csv",
48        test_acc_matrix_plot_name: str | None = None,
49        test_ave_acc_plot_name: str | None = None,
50    ) -> None:
51        r"""
52        **Args:**
53        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
54        - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
55        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
56        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
57        """
58        super().__init__(save_dir=save_dir)
59
60        # paths
61        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
62        r"""The path to save test accuracy matrix and average accuracy CSV file."""
63        if test_acc_matrix_plot_name:
64            self.test_acc_matrix_plot_path: str = os.path.join(
65                save_dir, test_acc_matrix_plot_name
66            )
67            r"""The path to save test accuracy matrix plot."""
68        if test_ave_acc_plot_name:
69            self.test_ave_acc_plot_path: str = os.path.join(
70                save_dir, test_ave_acc_plot_name
71            )
72            r"""The path to save test average accuracy curve plot."""
73
74        # training accumulated metrics
75        self.acc_training_epoch: MeanMetricBatch
76        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
77
78        # validation accumulated metrics
79        self.acc_val: MeanMetricBatch
80        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
81
82        # test accumulated metrics
83        self.acc_test: dict[int, MeanMetricBatch]
84        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
85
86        # task ID control
87        self.task_id: int
88        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

Args:

  • save_dir (str): The directory where data and figures of metrics will be saved. Better inside the output folder.
  • test_acc_csv_name (str): file name to save test accuracy matrix and average accuracy as CSV file.
  • test_acc_matrix_plot_name (str | None): file name to save accuracy matrix plot. If None, no file will be saved.
  • test_ave_acc_plot_name (str | None): file name to save average accuracy as curve plot over different training tasks. If None, no file will be saved.
test_acc_csv_path: str

The path to save test accuracy matrix and average accuracy CSV file.

Classification accuracy of training epoch. Accumulated and calculated from the training batches. See here for details.

Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See here for details.

Test classification accuracy of the current model (self.task_id) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to cl_dataset.num_tasks.

@rank_zero_only
def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
 90    @rank_zero_only
 91    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 92        r"""Initialize training and validation metrics."""
 93
 94        # set the current task_id from the `CLAlgorithm` object
 95        self.task_id = pl_module.task_id
 96
 97        # get the device to put the metrics on the same device
 98        device = pl_module.device
 99
100        # initialize training metrics
101        self.acc_training_epoch = MeanMetricBatch().to(device)
102
103        # initialize validation metrics
104        self.acc_val = MeanMetricBatch().to(device)

Initialize training and validation metrics.

@rank_zero_only
def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
106    @rank_zero_only
107    def on_train_batch_end(
108        self,
109        trainer: Trainer,
110        pl_module: CLAlgorithm,
111        outputs: dict[str, Any],
112        batch: Any,
113        batch_idx: int,
114    ) -> None:
115        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
116
117        **Args:**
118        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
119        - **batch** (`Any`): the training data batch.
120        """
121        # get the batch size
122        batch_size = len(batch)
123
124        # get training metrics values of current training batch from the outputs of the `training_step()`
125        acc_batch = outputs["acc"]
126
127        # update accumulated training metrics to calculate training metrics of the epoch
128        self.acc_training_epoch.update(acc_batch, batch_size)
129
130        # log training metrics of current training batch to Lightning loggers
131        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)
132
133        # log accumulated training metrics till this training batch to Lightning loggers
134        pl_module.log(
135            f"task_{self.task_id}/train/acc",
136            self.acc_training_epoch.compute(),
137            prog_bar=True,
138        )

Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, the returns of the training_step() method in the CLAlgorithm.
  • batch (Any): the training data batch.
@rank_zero_only
def on_train_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
140    @rank_zero_only
141    def on_train_epoch_end(
142        self,
143        trainer: Trainer,
144        pl_module: CLAlgorithm,
145    ) -> None:
146        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
147
148        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
149        pl_module.log(
150            f"task_{self.task_id}/learning_curve/train/acc",
151            self.acc_training_epoch.compute(),
152            on_epoch=True,
153            prog_bar=True,
154        )
155
156        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
157        self.acc_training_epoch.reset()

Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.

@rank_zero_only
def on_validation_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
159    @rank_zero_only
160    def on_validation_batch_end(
161        self,
162        trainer: Trainer,
163        pl_module: CLAlgorithm,
164        outputs: dict[str, Any],
165        batch: Any,
166        batch_idx: int,
167    ) -> None:
168        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
169
170        **Args:**
171        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
172        - **batch** (`Any`): the validation data batch.
173        """
174
175        # get the batch size
176        batch_size = len(batch)
177
178        # get the metrics values of the batch from the outputs
179        acc_batch = outputs["acc"]
180
181        # update the accumulated metrics in order to calculate the validation metrics
182        self.acc_val.update(acc_batch, batch_size)

Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

Args:

  • outputs (dict[str, Any]): the outputs of the validation step, which is the returns of the validation_step() method in the CLAlgorithm.
  • batch (Any): the validation data batch.
@rank_zero_only
def on_validation_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
184    @rank_zero_only
185    def on_validation_epoch_end(
186        self,
187        trainer: Trainer,
188        pl_module: CLAlgorithm,
189    ) -> None:
190        r"""Log validation metrics to plot learning curves."""
191
192        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
193        pl_module.log(
194            f"task_{self.task_id}/learning_curve/val/acc",
195            self.acc_val.compute(),
196            on_epoch=True,
197            prog_bar=True,
198        )

Log validation metrics to plot learning curves.

@rank_zero_only
def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
200    @rank_zero_only
201    def on_test_start(
202        self,
203        trainer: Trainer,
204        pl_module: CLAlgorithm,
205    ) -> None:
206        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
207
208        # set the current task_id again (double checking) from the `CLAlgorithm` object
209        self.task_id = pl_module.task_id
210
211        # get the device to put the metrics on the same device
212        device = pl_module.device
213
214        # initialize test metrics for current and previous tasks
215        self.acc_test = {
216            task_id: MeanMetricBatch().to(device)
217            for task_id in pl_module.processed_task_ids
218        }

Initialize the metrics for testing each seen task in the beginning of a task's testing.

@rank_zero_only
def on_test_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
220    @rank_zero_only
221    def on_test_batch_end(
222        self,
223        trainer: Trainer,
224        pl_module: CLAlgorithm,
225        outputs: dict[str, Any],
226        batch: Any,
227        batch_idx: int,
228        dataloader_idx: int = 0,
229    ) -> None:
230        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
231
232        **Args:**
233        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
234        - **batch** (`Any`): the test data batch.
235        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
236        """
237
238        # get the batch size
239        batch_size = len(batch)
240
241        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
242
243        # get the metrics values of the batch from the outputs
244        acc_batch = outputs["acc"]
245
246        # update the accumulated metrics in order to calculate the metrics of the epoch
247        self.acc_test[test_task_id].update(acc_batch, batch_size)

Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, which is the returns of the test_step() method in the CLAlgorithm.
  • batch (Any): the test data batch.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.
@rank_zero_only
def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
249    @rank_zero_only
250    def on_test_epoch_end(
251        self,
252        trainer: Trainer,
253        pl_module: CLAlgorithm,
254    ) -> None:
255        r"""Save and plot test metrics at the end of test."""
256
257        # save (update) the test metrics to CSV files
258        self.update_test_acc_to_csv(
259            after_training_task_id=self.task_id,
260            csv_path=self.test_acc_csv_path,
261        )
262
263        # plot the test metrics
264        if hasattr(self, "test_acc_matrix_plot_path"):
265            self.plot_test_acc_matrix_from_csv(
266                csv_path=self.test_acc_csv_path,
267                plot_path=self.test_acc_matrix_plot_path,
268            )
269        if hasattr(self, "test_ave_acc_plot_path"):
270            self.plot_test_ave_acc_curve_from_csv(
271                csv_path=self.test_acc_csv_path,
272                plot_path=self.test_ave_acc_plot_path,
273            )

Save and plot test metrics at the end of test.

def update_test_acc_to_csv(self, after_training_task_id: int, csv_path: str) -> None:
275    def update_test_acc_to_csv(
276        self,
277        after_training_task_id: int,
278        csv_path: str,
279    ) -> None:
280        r"""Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.
281
282        **Args:**
283        - **after_training_task_id** (`int`): the task ID after training.
284        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
285        """
286        processed_task_ids = list(self.acc_test.keys())
287        fieldnames = ["after_training_task", "average_accuracy"] + [
288            f"test_on_task_{task_id}" for task_id in processed_task_ids
289        ]
290
291        new_line = {
292            "after_training_task": after_training_task_id
293        }  # construct the first column
294
295        # construct the columns and calculate the average accuracy over tasks at the same time
296        average_accuracy_over_tasks = MeanMetric().to(
297            device=next(iter(self.acc_test.values())).device
298        )
299        for task_id in processed_task_ids:
300            acc = self.acc_test[task_id].compute().item()
301            new_line[f"test_on_task_{task_id}"] = acc
302            average_accuracy_over_tasks(acc)
303        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()
304
305        # write to the csv file
306        is_first = not os.path.exists(csv_path)
307        if not is_first:
308            with open(csv_path, "r", encoding="utf-8") as file:
309                lines = file.readlines()
310                del lines[0]
311        # write header
312        with open(csv_path, "w", encoding="utf-8") as file:
313            writer = csv.DictWriter(file, fieldnames=fieldnames)
314            writer.writeheader()
315        # write metrics
316        with open(csv_path, "a", encoding="utf-8") as file:
317            if not is_first:
318                file.writelines(lines)  # write the previous lines
319            writer = csv.DictWriter(file, fieldnames=fieldnames)
320            writer.writerow(new_line)

Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.

Args:

  • after_training_task_id (int): the task ID after training.
  • csv_path (str): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
322    def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
323        """Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory.
324
325        **Args:**
326        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
327        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
328        """
329        data = pd.read_csv(csv_path)
330        processed_task_ids = [
331            int(col.replace("test_on_task_", ""))
332            for col in data.columns
333            if col.startswith("test_on_task_")
334        ]
335
336        # Get all columns that start with "test_on_task_"
337        test_task_cols = [
338            col for col in data.columns if col.startswith("test_on_task_")
339        ]
340        num_tasks = len(processed_task_ids)
341        num_rows = len(data)
342
343        # Build the accuracy matrix
344        acc_matrix = data[test_task_cols].values
345
346        fig, ax = plt.subplots(
347            figsize=(2 * num_tasks, 2 * num_rows)
348        )  # adaptive figure size
349
350        cax = ax.imshow(
351            acc_matrix,
352            interpolation="nearest",
353            cmap="Greens",
354            vmin=0,
355            vmax=1,
356            aspect="auto",
357        )
358
359        colorbar = fig.colorbar(cax)
360        yticks = colorbar.ax.get_yticks()
361        colorbar.ax.set_yticks(yticks)
362        colorbar.ax.set_yticklabels(
363            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
364        )
365
366        # Annotate each cell
367        for r in range(num_rows):
368            for c in range(r + 1):
369                ax.text(
370                    c,
371                    r,
372                    f"{acc_matrix[r, c]:.3f}",
373                    ha="center",
374                    va="center",
375                    color="black",
376                    fontsize=10 + num_tasks,
377                )
378
379        ax.set_xticks(range(num_tasks))
380        ax.set_yticks(range(num_rows))
381        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
382        ax.set_yticklabels(
383            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
384        )
385
386        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
387        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
388        fig.tight_layout()
389        fig.savefig(plot_path)
390        plt.close(fig)

Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory.

Args:

  • csv_path (str): the path to the CSV file where the utils.update_test_acc_to_csv() saved the test accuracy metric.
  • plot_path (str): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
392    def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
393        """Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory.
394
395        **Args:**
396        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
397        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
398        """
399        data = pd.read_csv(csv_path)
400        after_training_tasks = data["after_training_task"].astype(int).tolist()
401
402        # plot the average accuracy curve over different training tasks
403        fig, ax = plt.subplots(figsize=(16, 9))
404        ax.plot(
405            after_training_tasks,
406            data["average_accuracy"],
407            marker="o",
408            linewidth=2,
409        )
410        ax.set_xlabel("After training task $t$", fontsize=16)
411        ax.set_ylabel("Average Accuracy (AA)", fontsize=16)
412        ax.grid(True)
413        xticks = after_training_tasks
414        yticks = [i * 0.05 for i in range(21)]
415        ax.set_xticks(xticks)
416        ax.set_yticks(yticks)
417        ax.set_xticklabels(xticks, fontsize=16)
418        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
419        fig.savefig(plot_path)
420        plt.close(fig)

Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory.

Args:

  • csv_path (str): the path to the CSV file where the utils.update_test_acc_to_csv() saved the test accuracy metric.
  • plot_path (str): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.