clarena.metrics.cl_loss

The submodule in metrics for CLLoss.

  1r"""
  2The submodule in `metrics` for `CLLoss`.
  3"""
  4
  5__all__ = ["CLLoss"]
  6
  7import csv
  8import logging
  9import os
 10from typing import Any
 11
 12import pandas as pd
 13from lightning import Trainer
 14from lightning.pytorch.utilities import rank_zero_only
 15from matplotlib import pyplot as plt
 16from torchmetrics import MeanMetric
 17
 18from clarena.cl_algorithms import CLAlgorithm
 19from clarena.metrics import MetricCallback
 20from clarena.utils.metrics import MeanMetricBatch
 21
 22# always get logger for built-in logging in each module
 23pylogger = logging.getLogger(__name__)
 24
 25
 26class CLLoss(MetricCallback):
 27    r"""Provides all actions that are related to CL loss metrics, which include:
 28
 29    - Defining, initializing and recording loss metrics.
 30    - Logging training and validation loss metrics to Lightning loggers in real time.
 31    - Saving test loss metrics to files.
 32    - Visualizing test loss metrics as plots.
 33
 34
 35    The callback is able to produce the following outputs:
 36
 37    - CSV files for classification loss (lower triangular) matrix and average classification loss. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 38    - Coloured plot for test classification loss (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 39    - Curve plots for test average classification loss over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.
 40
 41    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
 42    """
 43
 44    def __init__(
 45        self,
 46        save_dir: str,
 47        test_loss_cls_csv_name: str = "loss_cls.csv",
 48        test_loss_cls_matrix_plot_name: str | None = None,
 49        test_ave_loss_cls_plot_name: str | None = None,
 50    ) -> None:
 51        r"""
 52        **Args:**
 53        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 54        - **test_loss_cls_csv_name**(`str`): file name to save classification loss matrix and average classification loss as CSV file.
 55        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
 56        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
 57        """
 58        super().__init__(save_dir=save_dir)
 59
 60        self.test_loss_cls_csv_path: str = os.path.join(
 61            save_dir, test_loss_cls_csv_name
 62        )
 63        r"""The path to save test classification loss matrix and average classification loss CSV file."""
 64        if test_loss_cls_matrix_plot_name:
 65            self.test_loss_cls_matrix_plot_path: str = os.path.join(
 66                save_dir, test_loss_cls_matrix_plot_name
 67            )
 68            r"""The path to save test classification loss matrix plot."""
 69        if test_ave_loss_cls_plot_name:
 70            self.test_ave_loss_cls_plot_path: str = os.path.join(
 71                save_dir, test_ave_loss_cls_plot_name
 72            )
 73            r"""The path to save test average classification loss curve plot."""
 74
 75        # training accumulated metrics
 76        self.loss_cls_training_epoch: MeanMetricBatch
 77        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 78        self.loss_training_epoch: MeanMetricBatch
 79        r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 80
 81        # validation accumulated metrics
 82        self.loss_cls_val: MeanMetricBatch
 83        r"""Validation classification of the model loss after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
 84
 85        # test accumulated metrics
 86        self.loss_cls_test: dict[int, MeanMetricBatch]
 87        r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
 88
 89        # task ID control
 90        self.task_id: int
 91        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 92
 93    @rank_zero_only
 94    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 95        r"""Initialize training and validation metrics."""
 96
 97        # set the current task_id from the `CLAlgorithm` object
 98        self.task_id = pl_module.task_id
 99
100        # get the device to put the metrics on the same device
101        device = pl_module.device
102
103        # initialize training metrics
104        self.loss_cls_training_epoch = MeanMetricBatch().to(device)
105        self.loss_training_epoch = MeanMetricBatch().to(device)
106
107        # initialize validation metrics
108        self.loss_cls_val = MeanMetricBatch().to(device)
109
110    @rank_zero_only
111    def on_train_batch_end(
112        self,
113        trainer: Trainer,
114        pl_module: CLAlgorithm,
115        outputs: dict[str, Any],
116        batch: Any,
117        batch_idx: int,
118    ) -> None:
119        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
120
121        **Args:**
122        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
123        - **batch** (`Any`): the training data batch.
124        """
125        # get the batch size
126        batch_size = len(batch)
127
128        # get training metrics values of current training batch from the outputs of the `training_step()`
129        loss_cls_batch = outputs["loss_cls"]
130        loss_batch = outputs["loss"]
131
132        # update accumulated training metrics to calculate training metrics of the epoch
133        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
134        self.loss_training_epoch.update(loss_batch, batch_size)
135
136        # log training metrics of current training batch to Lightning loggers
137        pl_module.log(
138            f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True
139        )
140        pl_module.log(
141            f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True
142        )
143
144        # log accumulated training metrics till this training batch to Lightning loggers
145        pl_module.log(
146            f"task_{self.task_id}/train/loss_cls",
147            self.loss_cls_training_epoch.compute(),
148            prog_bar=True,
149        )
150        pl_module.log(
151            f"task_{self.task_id}/train/loss",
152            self.loss_training_epoch.compute(),
153            prog_bar=True,
154        )
155
156    @rank_zero_only
157    def on_train_epoch_end(
158        self,
159        trainer: Trainer,
160        pl_module: CLAlgorithm,
161    ) -> None:
162        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
163
164        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
165        pl_module.log(
166            f"task_{self.task_id}/learning_curve/train/loss_cls",
167            self.loss_cls_training_epoch.compute(),
168            on_epoch=True,
169            prog_bar=True,
170        )
171
172        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
173        self.loss_cls_training_epoch.reset()
174        self.loss_training_epoch.reset()
175
176    @rank_zero_only
177    def on_validation_batch_end(
178        self,
179        trainer: Trainer,
180        pl_module: CLAlgorithm,
181        outputs: dict[str, Any],
182        batch: Any,
183        batch_idx: int,
184    ) -> None:
185        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
186
187        **Args:**
188        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
189        - **batch** (`Any`): the validation data batch.
190        """
191
192        # get the batch size
193        batch_size = len(batch)
194
195        # get the metrics values of the batch from the outputs
196        loss_cls_batch = outputs["loss_cls"]
197
198        # update the accumulated metrics in order to calculate the validation metrics
199        self.loss_cls_val.update(loss_cls_batch, batch_size)
200
201    @rank_zero_only
202    def on_validation_epoch_end(
203        self,
204        trainer: Trainer,
205        pl_module: CLAlgorithm,
206    ) -> None:
207        r"""Log validation metrics to plot learning curves."""
208
209        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
210        pl_module.log(
211            f"task_{self.task_id}/learning_curve/val/loss_cls",
212            self.loss_cls_val.compute(),
213            on_epoch=True,
214            prog_bar=True,
215        )
216
217    @rank_zero_only
218    def on_test_start(
219        self,
220        trainer: Trainer,
221        pl_module: CLAlgorithm,
222    ) -> None:
223        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
224
225        # set the current task_id again (double checking) from the `CLAlgorithm` object
226        self.task_id = pl_module.task_id
227
228        # get the device to put the metrics on the same device
229        device = pl_module.device
230
231        # initialize test metrics for current and previous tasks
232        self.loss_cls_test = {
233            task_id: MeanMetricBatch().to(device)
234            for task_id in pl_module.processed_task_ids
235        }
236
237    @rank_zero_only
238    def on_test_batch_end(
239        self,
240        trainer: Trainer,
241        pl_module: CLAlgorithm,
242        outputs: dict[str, Any],
243        batch: Any,
244        batch_idx: int,
245        dataloader_idx: int = 0,
246    ) -> None:
247        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
248
249        **Args:**
250        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
251        - **batch** (`Any`): the test data batch.
252        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
253        """
254
255        # get the batch size
256        batch_size = len(batch)
257
258        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
259
260        # get the metrics values of the batch from the outputs
261        loss_cls_batch = outputs["loss_cls"]
262
263        # update the accumulated metrics in order to calculate the metrics of the epoch
264        self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size)
265
266    @rank_zero_only
267    def on_test_epoch_end(
268        self,
269        trainer: Trainer,
270        pl_module: CLAlgorithm,
271    ) -> None:
272        r"""Save and plot test metrics at the end of test."""
273
274        # save (update) the test metrics to CSV files
275        self.update_test_loss_cls_to_csv(
276            after_training_task_id=self.task_id,
277            csv_path=self.test_loss_cls_csv_path,
278        )
279
280        # plot the test metrics
281        if hasattr(self, "test_loss_cls_matrix_plot_path"):
282            self.plot_test_loss_cls_matrix_from_csv(
283                csv_path=self.test_loss_cls_csv_path,
284                plot_path=self.test_loss_cls_matrix_plot_path,
285            )
286        if hasattr(self, "test_ave_loss_cls_plot_path"):
287            self.plot_test_ave_loss_cls_curve_from_csv(
288                csv_path=self.test_loss_cls_csv_path,
289                plot_path=self.test_ave_loss_cls_plot_path,
290            )
291
292    def update_test_loss_cls_to_csv(
293        self,
294        after_training_task_id: int,
295        csv_path: str,
296    ) -> None:
297        """Update the test classification loss metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.
298
299        **Args:**
300        - **after_training_task_id** (`int`): the task ID after training.
301        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
302        """
303        processed_task_ids = list(self.loss_cls_test.keys())
304        fieldnames = ["after_training_task", "average_classification_loss"] + [
305            f"test_on_task_{task_id}" for task_id in processed_task_ids
306        ]
307
308        new_line = {
309            "after_training_task": after_training_task_id
310        }  # construct the first column
311
312        # write to the columns and calculate the average classification loss over tasks at the same time
313        average_classification_loss_over_tasks = MeanMetric().to(
314            device=next(iter(self.loss_cls_test.values())).device
315        )
316        for task_id in processed_task_ids:
317            loss_cls = self.loss_cls_test[task_id].compute().item()
318            new_line[f"test_on_task_{task_id}"] = loss_cls
319            average_classification_loss_over_tasks(loss_cls)
320        new_line["average_classification_loss"] = (
321            average_classification_loss_over_tasks.compute().item()
322        )
323
324        # write to the csv file
325        is_first = not os.path.exists(csv_path)
326        if not is_first:
327            with open(csv_path, "r", encoding="utf-8") as file:
328                lines = file.readlines()
329                del lines[0]
330        # write header
331        with open(csv_path, "w", encoding="utf-8") as file:
332            writer = csv.DictWriter(file, fieldnames=fieldnames)
333            writer.writeheader()
334        # write metrics
335        with open(csv_path, "a", encoding="utf-8") as file:
336            if not is_first:
337                file.writelines(lines)  # write the previous lines
338            writer = csv.DictWriter(file, fieldnames=fieldnames)
339            writer.writerow(new_line)
340
341    def plot_test_loss_cls_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
342        """Plot the test classification loss matrix from saved CSV file and save the plot to the designated directory.
343
344        **Args:**
345        - **csv_path** (`str`): the path to the CSV file where the `utils.update_loss_cls_to_csv()` saved the test classification loss metric.
346        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls_matrix.png'.
347        """
348        data = pd.read_csv(csv_path)
349        processed_task_ids = [
350            int(col.replace("test_on_task_", ""))
351            for col in data.columns
352            if col.startswith("test_on_task_")
353        ]
354
355        # Get all columns that start with "test_on_task_"
356        test_task_cols = [
357            col for col in data.columns if col.startswith("test_on_task_")
358        ]
359        num_tasks = len(processed_task_ids)
360        num_rows = len(data)
361
362        # Build the loss matrix
363        loss_matrix = data[test_task_cols].values
364
365        fig, ax = plt.subplots(
366            figsize=(2 * num_tasks, 2 * num_rows)
367        )  # adaptive figure size
368
369        cax = ax.imshow(
370            loss_matrix,
371            interpolation="nearest",
372            cmap="Greens",
373            aspect="auto",
374        )
375
376        colorbar = fig.colorbar(cax)
377        yticks = colorbar.ax.get_yticks()
378        colorbar.ax.set_yticks(yticks)
379        colorbar.ax.set_yticklabels(
380            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
381        )
382
383        # Annotate each cell
384        for r in range(num_rows):
385            for c in range(r + 1):
386                ax.text(
387                    c,
388                    r,
389                    f"{loss_matrix[r, c]:.3f}",
390                    ha="center",
391                    va="center",
392                    color="black",
393                    fontsize=10 + num_tasks,
394                )
395
396        ax.set_xticks(range(num_tasks))
397        ax.set_yticks(range(num_rows))
398        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
399        ax.set_yticklabels(
400            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
401        )
402
403        # Labeling the axes
404        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
405        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
406        fig.tight_layout()
407        fig.savefig(plot_path)
408        plt.close(fig)
409
410    def plot_test_ave_loss_cls_curve_from_csv(
411        self, csv_path: str, plot_path: str
412    ) -> None:
413        """Plot the test average classfication loss curve over different training tasks from saved CSV file and save the plot to the designated directory.
414
415        **Args:**
416        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test classfication loss metric.
417        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_loss_cls.png'.
418        """
419        data = pd.read_csv(csv_path)
420        after_training_tasks = data["after_training_task"].astype(int).tolist()
421
422        # plot the average accuracy curve over different training tasks
423        fig, ax = plt.subplots(figsize=(16, 9))
424        ax.plot(
425            after_training_tasks,
426            data["average_classification_loss"],
427            marker="o",
428            linewidth=2,
429        )
430        ax.set_xlabel("After training task $t$", fontsize=16)
431        ax.set_ylabel("Average Classification Loss", fontsize=16)
432        ax.grid(True)
433        xticks = after_training_tasks
434        yticks = [i * 0.05 for i in range(21)]
435        ax.set_xticks(xticks)
436        ax.set_yticks(yticks)
437        ax.set_xticklabels(xticks, fontsize=16)
438        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
439        fig.savefig(plot_path)
440        plt.close(fig)
class CLLoss(clarena.metrics.base.MetricCallback):
 27class CLLoss(MetricCallback):
 28    r"""Provides all actions that are related to CL loss metrics, which include:
 29
 30    - Defining, initializing and recording loss metrics.
 31    - Logging training and validation loss metrics to Lightning loggers in real time.
 32    - Saving test loss metrics to files.
 33    - Visualizing test loss metrics as plots.
 34
 35
 36    The callback is able to produce the following outputs:
 37
 38    - CSV files for classification loss (lower triangular) matrix and average classification loss. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 39    - Coloured plot for test classification loss (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 40    - Curve plots for test average classification loss over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.
 41
 42    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
 43    """
 44
 45    def __init__(
 46        self,
 47        save_dir: str,
 48        test_loss_cls_csv_name: str = "loss_cls.csv",
 49        test_loss_cls_matrix_plot_name: str | None = None,
 50        test_ave_loss_cls_plot_name: str | None = None,
 51    ) -> None:
 52        r"""
 53        **Args:**
 54        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 55        - **test_loss_cls_csv_name**(`str`): file name to save classification loss matrix and average classification loss as CSV file.
 56        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
 57        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
 58        """
 59        super().__init__(save_dir=save_dir)
 60
 61        self.test_loss_cls_csv_path: str = os.path.join(
 62            save_dir, test_loss_cls_csv_name
 63        )
 64        r"""The path to save test classification loss matrix and average classification loss CSV file."""
 65        if test_loss_cls_matrix_plot_name:
 66            self.test_loss_cls_matrix_plot_path: str = os.path.join(
 67                save_dir, test_loss_cls_matrix_plot_name
 68            )
 69            r"""The path to save test classification loss matrix plot."""
 70        if test_ave_loss_cls_plot_name:
 71            self.test_ave_loss_cls_plot_path: str = os.path.join(
 72                save_dir, test_ave_loss_cls_plot_name
 73            )
 74            r"""The path to save test average classification loss curve plot."""
 75
 76        # training accumulated metrics
 77        self.loss_cls_training_epoch: MeanMetricBatch
 78        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 79        self.loss_training_epoch: MeanMetricBatch
 80        r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 81
 82        # validation accumulated metrics
 83        self.loss_cls_val: MeanMetricBatch
 84        r"""Validation classification of the model loss after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
 85
 86        # test accumulated metrics
 87        self.loss_cls_test: dict[int, MeanMetricBatch]
 88        r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
 89
 90        # task ID control
 91        self.task_id: int
 92        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
 93
 94    @rank_zero_only
 95    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 96        r"""Initialize training and validation metrics."""
 97
 98        # set the current task_id from the `CLAlgorithm` object
 99        self.task_id = pl_module.task_id
100
101        # get the device to put the metrics on the same device
102        device = pl_module.device
103
104        # initialize training metrics
105        self.loss_cls_training_epoch = MeanMetricBatch().to(device)
106        self.loss_training_epoch = MeanMetricBatch().to(device)
107
108        # initialize validation metrics
109        self.loss_cls_val = MeanMetricBatch().to(device)
110
111    @rank_zero_only
112    def on_train_batch_end(
113        self,
114        trainer: Trainer,
115        pl_module: CLAlgorithm,
116        outputs: dict[str, Any],
117        batch: Any,
118        batch_idx: int,
119    ) -> None:
120        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
121
122        **Args:**
123        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
124        - **batch** (`Any`): the training data batch.
125        """
126        # get the batch size
127        batch_size = len(batch)
128
129        # get training metrics values of current training batch from the outputs of the `training_step()`
130        loss_cls_batch = outputs["loss_cls"]
131        loss_batch = outputs["loss"]
132
133        # update accumulated training metrics to calculate training metrics of the epoch
134        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
135        self.loss_training_epoch.update(loss_batch, batch_size)
136
137        # log training metrics of current training batch to Lightning loggers
138        pl_module.log(
139            f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True
140        )
141        pl_module.log(
142            f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True
143        )
144
145        # log accumulated training metrics till this training batch to Lightning loggers
146        pl_module.log(
147            f"task_{self.task_id}/train/loss_cls",
148            self.loss_cls_training_epoch.compute(),
149            prog_bar=True,
150        )
151        pl_module.log(
152            f"task_{self.task_id}/train/loss",
153            self.loss_training_epoch.compute(),
154            prog_bar=True,
155        )
156
157    @rank_zero_only
158    def on_train_epoch_end(
159        self,
160        trainer: Trainer,
161        pl_module: CLAlgorithm,
162    ) -> None:
163        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
164
165        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
166        pl_module.log(
167            f"task_{self.task_id}/learning_curve/train/loss_cls",
168            self.loss_cls_training_epoch.compute(),
169            on_epoch=True,
170            prog_bar=True,
171        )
172
173        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
174        self.loss_cls_training_epoch.reset()
175        self.loss_training_epoch.reset()
176
177    @rank_zero_only
178    def on_validation_batch_end(
179        self,
180        trainer: Trainer,
181        pl_module: CLAlgorithm,
182        outputs: dict[str, Any],
183        batch: Any,
184        batch_idx: int,
185    ) -> None:
186        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
187
188        **Args:**
189        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
190        - **batch** (`Any`): the validation data batch.
191        """
192
193        # get the batch size
194        batch_size = len(batch)
195
196        # get the metrics values of the batch from the outputs
197        loss_cls_batch = outputs["loss_cls"]
198
199        # update the accumulated metrics in order to calculate the validation metrics
200        self.loss_cls_val.update(loss_cls_batch, batch_size)
201
202    @rank_zero_only
203    def on_validation_epoch_end(
204        self,
205        trainer: Trainer,
206        pl_module: CLAlgorithm,
207    ) -> None:
208        r"""Log validation metrics to plot learning curves."""
209
210        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
211        pl_module.log(
212            f"task_{self.task_id}/learning_curve/val/loss_cls",
213            self.loss_cls_val.compute(),
214            on_epoch=True,
215            prog_bar=True,
216        )
217
218    @rank_zero_only
219    def on_test_start(
220        self,
221        trainer: Trainer,
222        pl_module: CLAlgorithm,
223    ) -> None:
224        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
225
226        # set the current task_id again (double checking) from the `CLAlgorithm` object
227        self.task_id = pl_module.task_id
228
229        # get the device to put the metrics on the same device
230        device = pl_module.device
231
232        # initialize test metrics for current and previous tasks
233        self.loss_cls_test = {
234            task_id: MeanMetricBatch().to(device)
235            for task_id in pl_module.processed_task_ids
236        }
237
238    @rank_zero_only
239    def on_test_batch_end(
240        self,
241        trainer: Trainer,
242        pl_module: CLAlgorithm,
243        outputs: dict[str, Any],
244        batch: Any,
245        batch_idx: int,
246        dataloader_idx: int = 0,
247    ) -> None:
248        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
249
250        **Args:**
251        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
252        - **batch** (`Any`): the test data batch.
253        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
254        """
255
256        # get the batch size
257        batch_size = len(batch)
258
259        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
260
261        # get the metrics values of the batch from the outputs
262        loss_cls_batch = outputs["loss_cls"]
263
264        # update the accumulated metrics in order to calculate the metrics of the epoch
265        self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size)
266
267    @rank_zero_only
268    def on_test_epoch_end(
269        self,
270        trainer: Trainer,
271        pl_module: CLAlgorithm,
272    ) -> None:
273        r"""Save and plot test metrics at the end of test."""
274
275        # save (update) the test metrics to CSV files
276        self.update_test_loss_cls_to_csv(
277            after_training_task_id=self.task_id,
278            csv_path=self.test_loss_cls_csv_path,
279        )
280
281        # plot the test metrics
282        if hasattr(self, "test_loss_cls_matrix_plot_path"):
283            self.plot_test_loss_cls_matrix_from_csv(
284                csv_path=self.test_loss_cls_csv_path,
285                plot_path=self.test_loss_cls_matrix_plot_path,
286            )
287        if hasattr(self, "test_ave_loss_cls_plot_path"):
288            self.plot_test_ave_loss_cls_curve_from_csv(
289                csv_path=self.test_loss_cls_csv_path,
290                plot_path=self.test_ave_loss_cls_plot_path,
291            )
292
293    def update_test_loss_cls_to_csv(
294        self,
295        after_training_task_id: int,
296        csv_path: str,
297    ) -> None:
298        """Update the test classification loss metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.
299
300        **Args:**
301        - **after_training_task_id** (`int`): the task ID after training.
302        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
303        """
304        processed_task_ids = list(self.loss_cls_test.keys())
305        fieldnames = ["after_training_task", "average_classification_loss"] + [
306            f"test_on_task_{task_id}" for task_id in processed_task_ids
307        ]
308
309        new_line = {
310            "after_training_task": after_training_task_id
311        }  # construct the first column
312
313        # write to the columns and calculate the average classification loss over tasks at the same time
314        average_classification_loss_over_tasks = MeanMetric().to(
315            device=next(iter(self.loss_cls_test.values())).device
316        )
317        for task_id in processed_task_ids:
318            loss_cls = self.loss_cls_test[task_id].compute().item()
319            new_line[f"test_on_task_{task_id}"] = loss_cls
320            average_classification_loss_over_tasks(loss_cls)
321        new_line["average_classification_loss"] = (
322            average_classification_loss_over_tasks.compute().item()
323        )
324
325        # write to the csv file
326        is_first = not os.path.exists(csv_path)
327        if not is_first:
328            with open(csv_path, "r", encoding="utf-8") as file:
329                lines = file.readlines()
330                del lines[0]
331        # write header
332        with open(csv_path, "w", encoding="utf-8") as file:
333            writer = csv.DictWriter(file, fieldnames=fieldnames)
334            writer.writeheader()
335        # write metrics
336        with open(csv_path, "a", encoding="utf-8") as file:
337            if not is_first:
338                file.writelines(lines)  # write the previous lines
339            writer = csv.DictWriter(file, fieldnames=fieldnames)
340            writer.writerow(new_line)
341
342    def plot_test_loss_cls_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
343        """Plot the test classification loss matrix from saved CSV file and save the plot to the designated directory.
344
345        **Args:**
346        - **csv_path** (`str`): the path to the CSV file where the `utils.update_loss_cls_to_csv()` saved the test classification loss metric.
347        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls_matrix.png'.
348        """
349        data = pd.read_csv(csv_path)
350        processed_task_ids = [
351            int(col.replace("test_on_task_", ""))
352            for col in data.columns
353            if col.startswith("test_on_task_")
354        ]
355
356        # Get all columns that start with "test_on_task_"
357        test_task_cols = [
358            col for col in data.columns if col.startswith("test_on_task_")
359        ]
360        num_tasks = len(processed_task_ids)
361        num_rows = len(data)
362
363        # Build the loss matrix
364        loss_matrix = data[test_task_cols].values
365
366        fig, ax = plt.subplots(
367            figsize=(2 * num_tasks, 2 * num_rows)
368        )  # adaptive figure size
369
370        cax = ax.imshow(
371            loss_matrix,
372            interpolation="nearest",
373            cmap="Greens",
374            aspect="auto",
375        )
376
377        colorbar = fig.colorbar(cax)
378        yticks = colorbar.ax.get_yticks()
379        colorbar.ax.set_yticks(yticks)
380        colorbar.ax.set_yticklabels(
381            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
382        )
383
384        # Annotate each cell
385        for r in range(num_rows):
386            for c in range(r + 1):
387                ax.text(
388                    c,
389                    r,
390                    f"{loss_matrix[r, c]:.3f}",
391                    ha="center",
392                    va="center",
393                    color="black",
394                    fontsize=10 + num_tasks,
395                )
396
397        ax.set_xticks(range(num_tasks))
398        ax.set_yticks(range(num_rows))
399        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
400        ax.set_yticklabels(
401            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
402        )
403
404        # Labeling the axes
405        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
406        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
407        fig.tight_layout()
408        fig.savefig(plot_path)
409        plt.close(fig)
410
411    def plot_test_ave_loss_cls_curve_from_csv(
412        self, csv_path: str, plot_path: str
413    ) -> None:
414        """Plot the test average classfication loss curve over different training tasks from saved CSV file and save the plot to the designated directory.
415
416        **Args:**
417        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test classfication loss metric.
418        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_loss_cls.png'.
419        """
420        data = pd.read_csv(csv_path)
421        after_training_tasks = data["after_training_task"].astype(int).tolist()
422
423        # plot the average accuracy curve over different training tasks
424        fig, ax = plt.subplots(figsize=(16, 9))
425        ax.plot(
426            after_training_tasks,
427            data["average_classification_loss"],
428            marker="o",
429            linewidth=2,
430        )
431        ax.set_xlabel("After training task $t$", fontsize=16)
432        ax.set_ylabel("Average Classification Loss", fontsize=16)
433        ax.grid(True)
434        xticks = after_training_tasks
435        yticks = [i * 0.05 for i in range(21)]
436        ax.set_xticks(xticks)
437        ax.set_yticks(yticks)
438        ax.set_xticklabels(xticks, fontsize=16)
439        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
440        fig.savefig(plot_path)
441        plt.close(fig)

Provides all actions that are related to CL loss metrics, which include:

  • Defining, initializing and recording loss metrics.
  • Logging training and validation loss metrics to Lightning loggers in real time.
  • Saving test loss metrics to files.
  • Visualizing test loss metrics as plots.

The callback is able to produce the following outputs:

  • CSV files for classification loss (lower triangular) matrix and average classification loss. See here for details.
  • Coloured plot for test classification loss (lower triangular) matrix. See here for details.
  • Curve plots for test average classification loss over different training tasks. See here for details.

Please refer to the A Summary of Continual Learning Metrics to learn about this metric.

CLLoss( save_dir: str, test_loss_cls_csv_name: str = 'loss_cls.csv', test_loss_cls_matrix_plot_name: str | None = None, test_ave_loss_cls_plot_name: str | None = None)
45    def __init__(
46        self,
47        save_dir: str,
48        test_loss_cls_csv_name: str = "loss_cls.csv",
49        test_loss_cls_matrix_plot_name: str | None = None,
50        test_ave_loss_cls_plot_name: str | None = None,
51    ) -> None:
52        r"""
53        **Args:**
54        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
55        - **test_loss_cls_csv_name**(`str`): file name to save classification loss matrix and average classification loss as CSV file.
56        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
57        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
58        """
59        super().__init__(save_dir=save_dir)
60
61        self.test_loss_cls_csv_path: str = os.path.join(
62            save_dir, test_loss_cls_csv_name
63        )
64        r"""The path to save test classification loss matrix and average classification loss CSV file."""
65        if test_loss_cls_matrix_plot_name:
66            self.test_loss_cls_matrix_plot_path: str = os.path.join(
67                save_dir, test_loss_cls_matrix_plot_name
68            )
69            r"""The path to save test classification loss matrix plot."""
70        if test_ave_loss_cls_plot_name:
71            self.test_ave_loss_cls_plot_path: str = os.path.join(
72                save_dir, test_ave_loss_cls_plot_name
73            )
74            r"""The path to save test average classification loss curve plot."""
75
76        # training accumulated metrics
77        self.loss_cls_training_epoch: MeanMetricBatch
78        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
79        self.loss_training_epoch: MeanMetricBatch
80        r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
81
82        # validation accumulated metrics
83        self.loss_cls_val: MeanMetricBatch
84        r"""Validation classification of the model loss after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
85
86        # test accumulated metrics
87        self.loss_cls_test: dict[int, MeanMetricBatch]
88        r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
89
90        # task ID control
91        self.task_id: int
92        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

Args:

  • save_dir (str): The directory where data and figures of metrics will be saved. Better inside the output folder.
  • test_loss_cls_csv_name (str): file name to save classification loss matrix and average classification loss as CSV file.
  • test_loss_cls_matrix_plot_name (str | None): file name to save classification loss matrix plot. If None, no file will be saved.
  • test_ave_loss_cls_plot_name (str | None): file name to save average classification loss as curve plot over different training tasks. If None, no file will be saved.
test_loss_cls_csv_path: str

The path to save test classification loss matrix and average classification loss CSV file.

loss_cls_training_epoch: clarena.utils.metrics.MeanMetricBatch

Classification loss of training epoch. Accumulated and calculated from the training batches. See here for details.

Total loss of training epoch. Accumulated and calculated from the training batches. See here for details.

Validation classification loss of the model after a training epoch. Accumulated and calculated from the validation batches. See here for details.

loss_cls_test: dict[int, clarena.utils.metrics.MeanMetricBatch]

Test classification loss of the current model (self.task_id) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to cl_dataset.num_tasks.

@rank_zero_only
def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
 94    @rank_zero_only
 95    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
 96        r"""Initialize training and validation metrics."""
 97
 98        # set the current task_id from the `CLAlgorithm` object
 99        self.task_id = pl_module.task_id
100
101        # get the device to put the metrics on the same device
102        device = pl_module.device
103
104        # initialize training metrics
105        self.loss_cls_training_epoch = MeanMetricBatch().to(device)
106        self.loss_training_epoch = MeanMetricBatch().to(device)
107
108        # initialize validation metrics
109        self.loss_cls_val = MeanMetricBatch().to(device)

Initialize training and validation metrics.

@rank_zero_only
def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
111    @rank_zero_only
112    def on_train_batch_end(
113        self,
114        trainer: Trainer,
115        pl_module: CLAlgorithm,
116        outputs: dict[str, Any],
117        batch: Any,
118        batch_idx: int,
119    ) -> None:
120        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
121
122        **Args:**
123        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
124        - **batch** (`Any`): the training data batch.
125        """
126        # get the batch size
127        batch_size = len(batch)
128
129        # get training metrics values of current training batch from the outputs of the `training_step()`
130        loss_cls_batch = outputs["loss_cls"]
131        loss_batch = outputs["loss"]
132
133        # update accumulated training metrics to calculate training metrics of the epoch
134        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
135        self.loss_training_epoch.update(loss_batch, batch_size)
136
137        # log training metrics of current training batch to Lightning loggers
138        pl_module.log(
139            f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True
140        )
141        pl_module.log(
142            f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True
143        )
144
145        # log accumulated training metrics till this training batch to Lightning loggers
146        pl_module.log(
147            f"task_{self.task_id}/train/loss_cls",
148            self.loss_cls_training_epoch.compute(),
149            prog_bar=True,
150        )
151        pl_module.log(
152            f"task_{self.task_id}/train/loss",
153            self.loss_training_epoch.compute(),
154            prog_bar=True,
155        )

Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, the returns of the training_step() method in the CLAlgorithm.
  • batch (Any): the training data batch.
@rank_zero_only
def on_train_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
157    @rank_zero_only
158    def on_train_epoch_end(
159        self,
160        trainer: Trainer,
161        pl_module: CLAlgorithm,
162    ) -> None:
163        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
164
165        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
166        pl_module.log(
167            f"task_{self.task_id}/learning_curve/train/loss_cls",
168            self.loss_cls_training_epoch.compute(),
169            on_epoch=True,
170            prog_bar=True,
171        )
172
173        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
174        self.loss_cls_training_epoch.reset()
175        self.loss_training_epoch.reset()

Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.

@rank_zero_only
def on_validation_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
177    @rank_zero_only
178    def on_validation_batch_end(
179        self,
180        trainer: Trainer,
181        pl_module: CLAlgorithm,
182        outputs: dict[str, Any],
183        batch: Any,
184        batch_idx: int,
185    ) -> None:
186        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
187
188        **Args:**
189        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
190        - **batch** (`Any`): the validation data batch.
191        """
192
193        # get the batch size
194        batch_size = len(batch)
195
196        # get the metrics values of the batch from the outputs
197        loss_cls_batch = outputs["loss_cls"]
198
199        # update the accumulated metrics in order to calculate the validation metrics
200        self.loss_cls_val.update(loss_cls_batch, batch_size)

Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

Args:

  • outputs (dict[str, Any]): the outputs of the validation step, which is the returns of the validation_step() method in the CLAlgorithm.
  • batch (Any): the validation data batch.
@rank_zero_only
def on_validation_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
202    @rank_zero_only
203    def on_validation_epoch_end(
204        self,
205        trainer: Trainer,
206        pl_module: CLAlgorithm,
207    ) -> None:
208        r"""Log validation metrics to plot learning curves."""
209
210        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
211        pl_module.log(
212            f"task_{self.task_id}/learning_curve/val/loss_cls",
213            self.loss_cls_val.compute(),
214            on_epoch=True,
215            prog_bar=True,
216        )

Log validation metrics to plot learning curves.

@rank_zero_only
def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
218    @rank_zero_only
219    def on_test_start(
220        self,
221        trainer: Trainer,
222        pl_module: CLAlgorithm,
223    ) -> None:
224        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
225
226        # set the current task_id again (double checking) from the `CLAlgorithm` object
227        self.task_id = pl_module.task_id
228
229        # get the device to put the metrics on the same device
230        device = pl_module.device
231
232        # initialize test metrics for current and previous tasks
233        self.loss_cls_test = {
234            task_id: MeanMetricBatch().to(device)
235            for task_id in pl_module.processed_task_ids
236        }

Initialize the metrics for testing each seen task in the beginning of a task's testing.

@rank_zero_only
def on_test_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
238    @rank_zero_only
239    def on_test_batch_end(
240        self,
241        trainer: Trainer,
242        pl_module: CLAlgorithm,
243        outputs: dict[str, Any],
244        batch: Any,
245        batch_idx: int,
246        dataloader_idx: int = 0,
247    ) -> None:
248        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
249
250        **Args:**
251        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
252        - **batch** (`Any`): the test data batch.
253        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
254        """
255
256        # get the batch size
257        batch_size = len(batch)
258
259        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
260
261        # get the metrics values of the batch from the outputs
262        loss_cls_batch = outputs["loss_cls"]
263
264        # update the accumulated metrics in order to calculate the metrics of the epoch
265        self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size)

Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, which is the returns of the test_step() method in the CLAlgorithm.
  • batch (Any): the test data batch.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.
@rank_zero_only
def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
267    @rank_zero_only
268    def on_test_epoch_end(
269        self,
270        trainer: Trainer,
271        pl_module: CLAlgorithm,
272    ) -> None:
273        r"""Save and plot test metrics at the end of test."""
274
275        # save (update) the test metrics to CSV files
276        self.update_test_loss_cls_to_csv(
277            after_training_task_id=self.task_id,
278            csv_path=self.test_loss_cls_csv_path,
279        )
280
281        # plot the test metrics
282        if hasattr(self, "test_loss_cls_matrix_plot_path"):
283            self.plot_test_loss_cls_matrix_from_csv(
284                csv_path=self.test_loss_cls_csv_path,
285                plot_path=self.test_loss_cls_matrix_plot_path,
286            )
287        if hasattr(self, "test_ave_loss_cls_plot_path"):
288            self.plot_test_ave_loss_cls_curve_from_csv(
289                csv_path=self.test_loss_cls_csv_path,
290                plot_path=self.test_ave_loss_cls_plot_path,
291            )

Save and plot test metrics at the end of test.

def update_test_loss_cls_to_csv(self, after_training_task_id: int, csv_path: str) -> None:
293    def update_test_loss_cls_to_csv(
294        self,
295        after_training_task_id: int,
296        csv_path: str,
297    ) -> None:
298        """Update the test classification loss metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.
299
300        **Args:**
301        - **after_training_task_id** (`int`): the task ID after training.
302        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
303        """
304        processed_task_ids = list(self.loss_cls_test.keys())
305        fieldnames = ["after_training_task", "average_classification_loss"] + [
306            f"test_on_task_{task_id}" for task_id in processed_task_ids
307        ]
308
309        new_line = {
310            "after_training_task": after_training_task_id
311        }  # construct the first column
312
313        # write to the columns and calculate the average classification loss over tasks at the same time
314        average_classification_loss_over_tasks = MeanMetric().to(
315            device=next(iter(self.loss_cls_test.values())).device
316        )
317        for task_id in processed_task_ids:
318            loss_cls = self.loss_cls_test[task_id].compute().item()
319            new_line[f"test_on_task_{task_id}"] = loss_cls
320            average_classification_loss_over_tasks(loss_cls)
321        new_line["average_classification_loss"] = (
322            average_classification_loss_over_tasks.compute().item()
323        )
324
325        # write to the csv file
326        is_first = not os.path.exists(csv_path)
327        if not is_first:
328            with open(csv_path, "r", encoding="utf-8") as file:
329                lines = file.readlines()
330                del lines[0]
331        # write header
332        with open(csv_path, "w", encoding="utf-8") as file:
333            writer = csv.DictWriter(file, fieldnames=fieldnames)
334            writer.writeheader()
335        # write metrics
336        with open(csv_path, "a", encoding="utf-8") as file:
337            if not is_first:
338                file.writelines(lines)  # write the previous lines
339            writer = csv.DictWriter(file, fieldnames=fieldnames)
340            writer.writerow(new_line)

Update the test classification loss metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.

Args:

  • after_training_task_id (int): the task ID after training.
  • csv_path (str): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
def plot_test_loss_cls_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
342    def plot_test_loss_cls_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
343        """Plot the test classification loss matrix from saved CSV file and save the plot to the designated directory.
344
345        **Args:**
346        - **csv_path** (`str`): the path to the CSV file where the `utils.update_loss_cls_to_csv()` saved the test classification loss metric.
347        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls_matrix.png'.
348        """
349        data = pd.read_csv(csv_path)
350        processed_task_ids = [
351            int(col.replace("test_on_task_", ""))
352            for col in data.columns
353            if col.startswith("test_on_task_")
354        ]
355
356        # Get all columns that start with "test_on_task_"
357        test_task_cols = [
358            col for col in data.columns if col.startswith("test_on_task_")
359        ]
360        num_tasks = len(processed_task_ids)
361        num_rows = len(data)
362
363        # Build the loss matrix
364        loss_matrix = data[test_task_cols].values
365
366        fig, ax = plt.subplots(
367            figsize=(2 * num_tasks, 2 * num_rows)
368        )  # adaptive figure size
369
370        cax = ax.imshow(
371            loss_matrix,
372            interpolation="nearest",
373            cmap="Greens",
374            aspect="auto",
375        )
376
377        colorbar = fig.colorbar(cax)
378        yticks = colorbar.ax.get_yticks()
379        colorbar.ax.set_yticks(yticks)
380        colorbar.ax.set_yticklabels(
381            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
382        )
383
384        # Annotate each cell
385        for r in range(num_rows):
386            for c in range(r + 1):
387                ax.text(
388                    c,
389                    r,
390                    f"{loss_matrix[r, c]:.3f}",
391                    ha="center",
392                    va="center",
393                    color="black",
394                    fontsize=10 + num_tasks,
395                )
396
397        ax.set_xticks(range(num_tasks))
398        ax.set_yticks(range(num_rows))
399        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
400        ax.set_yticklabels(
401            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
402        )
403
404        # Labeling the axes
405        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
406        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
407        fig.tight_layout()
408        fig.savefig(plot_path)
409        plt.close(fig)

Plot the test classification loss matrix from saved CSV file and save the plot to the designated directory.

Args:

  • csv_path (str): the path to the CSV file where update_test_loss_cls_to_csv() saved the test classification loss metrics.
  • plot_path (str): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls_matrix.png'.
def plot_test_ave_loss_cls_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
411    def plot_test_ave_loss_cls_curve_from_csv(
412        self, csv_path: str, plot_path: str
413    ) -> None:
414        """Plot the test average classfication loss curve over different training tasks from saved CSV file and save the plot to the designated directory.
415
416        **Args:**
417        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test classfication loss metric.
418        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_loss_cls.png'.
419        """
420        data = pd.read_csv(csv_path)
421        after_training_tasks = data["after_training_task"].astype(int).tolist()
422
423        # plot the average accuracy curve over different training tasks
424        fig, ax = plt.subplots(figsize=(16, 9))
425        ax.plot(
426            after_training_tasks,
427            data["average_classification_loss"],
428            marker="o",
429            linewidth=2,
430        )
431        ax.set_xlabel("After training task $t$", fontsize=16)
432        ax.set_ylabel("Average Classification Loss", fontsize=16)
433        ax.grid(True)
434        xticks = after_training_tasks
435        yticks = [i * 0.05 for i in range(21)]
436        ax.set_xticks(xticks)
437        ax.set_yticks(yticks)
438        ax.set_xticklabels(xticks, fontsize=16)
439        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
440        fig.savefig(plot_path)
441        plt.close(fig)

Plot the test average classification loss curve over different training tasks from saved CSV file and save the plot to the designated directory.

Args:

  • csv_path (str): the path to the CSV file in which utils.update_test_acc_to_csv() saved the test classification loss metric.
  • plot_path (str): the path to save the plot to. Preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/ave_loss_cls.png'.