clarena.callbacks.cl_metrics

The submodule in `callbacks` for `CLMetricsCallback`.

  1r"""
  2The submodule in `callbacks` for `CLMetricsCallback`.
  3"""
  4
  5__all__ = ["CLMetricsCallback"]
  6
  7import logging
  8import os
  9from typing import Any
 10
 11from lightning import Callback, Trainer
 12
 13from clarena.cl_algorithms import CLAlgorithm
 14from clarena.utils import MeanMetricBatch, plot, save
 15
 16# always get logger for built-in logging in each module
 17pylogger = logging.getLogger(__name__)
 18
 19
 20class CLMetricsCallback(Callback):
 21    r"""Provides all actions that are related to CL metrics, which include:
 22
 23    - Defining, initialising and recording metrics.
 24    - Logging training and validation metrics to Lightning loggers in real time.
 25    - Saving test metrics to files.
 26    - Visualising test metrics as plots.
 27
 28    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn what continual learning metrics mean.
 29
 30    Lightning provides `self.log()` to log metrics in `LightningModule` where our `CLAlgorithm` based. You can put `self.log()` here if you don't want to mess up the `CLAlgorithm` with a huge amount of logging codes.
 31
 32    The callback is able to produce the following outputs:
 33
 34    - CSV files for test accuracy and classification loss (lower triangular) matrix, average accuracy and classification loss. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 35    - Coloured plot for test accuracy and classification loss (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
 36    - Curve plots for test average accuracy and classification loss over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.
 37
 38
 39    """
 40
 41    def __init__(
 42        self,
 43        save_dir: str,
 44        test_acc_csv_name: str | None = None,
 45        test_loss_cls_csv_name: str | None = None,
 46        test_acc_matrix_plot_name: str | None = None,
 47        test_loss_cls_matrix_plot_name: str | None = None,
 48        test_ave_acc_plot_name: str | None = None,
 49        test_ave_loss_cls_plot_name: str | None = None,
 50    ) -> None:
 51        r"""Initialise the `CLMetricsCallback`.
 52
 53        **Args:**
 54        - **save_dir** (`str` | `None`): the directory to save the test metrics files and plots. Better inside the output folder.
 55        - **test_acc_csv_name** (`str` | `None`): file name to save test accuracy matrix and average accuracy as CSV file. If `None`, no file will be saved.
 56        - **test_loss_cls_csv_name**(`str` | `None`): file name to save classification loss matrix and average classification loss as CSV file. If `None`, no file will be saved.
 57        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
 58        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
 59        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
 60        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
 61        """
 62        Callback.__init__(self)
 63
 64        os.makedirs(save_dir, exist_ok=True)
 65
 66        # paths
 67        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 68        r"""Store the path to save test accuracy matrix and average accuracy CSV file."""
 69        self.test_loss_cls_csv_path: str = os.path.join(
 70            save_dir, test_loss_cls_csv_name
 71        )
 72        r"""Store the path to save test classification loss and average accuracy CSV file."""
 73        self.test_acc_matrix_plot_path: str = os.path.join(
 74            save_dir, test_acc_matrix_plot_name
 75        )
 76        r"""Store the path to save test accuracy matrix plot."""
 77        self.test_loss_cls_matrix_plot_path: str = os.path.join(
 78            save_dir, test_loss_cls_matrix_plot_name
 79        )
 80        r"""Store the path to save test classification loss matrix plot."""
 81        self.test_ave_acc_plot_path: str = os.path.join(
 82            save_dir, test_ave_acc_plot_name
 83        )
 84        r"""Store the path to save test average accuracy curve plot."""
 85        self.test_ave_loss_cls_plot_path: str = os.path.join(
 86            save_dir, test_ave_loss_cls_plot_name
 87        )
 88        r"""Store the path to save test average classification loss curve plot."""
 89
 90        # training accumulated metrics
 91        self.acc_training_epoch: MeanMetricBatch
 92        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 93        self.loss_cls_training_epoch: MeanMetricBatch
 94        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 95        self.loss_training_epoch: MeanMetricBatch
 96        r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 97
 98        # validation accumulated metrics
 99        self.acc_val: MeanMetricBatch
100        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
101        self.loss_cls_val: MeanMetricBatch
102        r"""Validation classification of the model loss after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
103
104        # test accumulated metrics
105        self.acc_test: dict[str, MeanMetricBatch]
106        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
107        self.loss_cls_test: dict[str, MeanMetricBatch]
108        r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
109
110        self.task_id: int
111        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""
112
113    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
114        r"""Initialise training and validation metrics."""
115
116        # set the current task_id from the `CLAlgorithm` object
117        self.task_id = pl_module.task_id
118
119        # initialise training metrics
120        self.loss_cls_training_epoch = MeanMetricBatch()
121        self.loss_training_epoch = MeanMetricBatch()
122        self.acc_training_epoch = MeanMetricBatch()
123
124        # initialise validation metrics
125        self.loss_cls_val = MeanMetricBatch()
126        self.acc_val = MeanMetricBatch()
127
128    def on_train_batch_end(
129        self,
130        trainer: Trainer,
131        pl_module: CLAlgorithm,
132        outputs: dict[str, Any],
133        batch: Any,
134        batch_idx: int,
135    ) -> None:
136        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
137
138        **Args:**
139        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
140        - **batch** (`Any`): the training data batch.
141        """
142        # get the batch size
143        batch_size = len(batch)
144
145        # get training metrics values of current training batch from the outputs of the `training_step()`
146        loss_cls_batch = outputs["loss_cls"]
147        loss_batch = outputs["loss"]
148        acc_batch = outputs["acc"]
149
150        # update accumulated training metrics to calculate training metrics of the epoch
151        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
152        self.loss_training_epoch.update(loss_cls_batch, batch_size)
153        self.acc_training_epoch.update(acc_batch, batch_size)
154
155        # log training metrics of current training batch to Lightning loggers
156        pl_module.log(
157            f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True
158        )
159        pl_module.log(
160            f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True
161        )
162        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)
163
164        # log accumulated training metrics till this training batch to Lightning loggers
165        pl_module.log(
166            f"task_{self.task_id}/train/loss_cls",
167            self.loss_cls_training_epoch.compute(),
168            prog_bar=True,
169        )
170        pl_module.log(
171            f"task_{self.task_id}/train/loss",
172            self.loss_training_epoch.compute(),
173            prog_bar=True,
174        )
175        pl_module.log(
176            f"task_{self.task_id}/train/acc",
177            self.acc_training_epoch.compute(),
178            prog_bar=True,
179        )
180
181    def on_train_epoch_end(
182        self,
183        trainer: Trainer,
184        pl_module: CLAlgorithm,
185    ) -> None:
186        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
187
188        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
189        pl_module.log(
190            f"task_{self.task_id}/learning_curve/train/loss_cls",
191            self.loss_cls_training_epoch.compute(),
192            on_epoch=True,
193            prog_bar=True,
194        )
195        pl_module.log(
196            f"task_{self.task_id}/learning_curve/train/acc",
197            self.acc_training_epoch.compute(),
198            on_epoch=True,
199            prog_bar=True,
200        )
201
202        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
203        self.loss_cls_training_epoch.reset()
204        self.loss_training_epoch.reset()
205        self.acc_training_epoch.reset()
206
207    def on_validation_batch_end(
208        self,
209        trainer: Trainer,
210        pl_module: CLAlgorithm,
211        outputs: dict[str, Any],
212        batch: Any,
213        batch_idx: int,
214    ) -> None:
215        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
216
217        **Args:**
218        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
219        - **batch** (`Any`): the validation data batch.
220        """
221
222        # get the batch size
223        batch_size = len(batch)
224
225        # get the metrics values of the batch from the outputs
226        loss_cls_batch = outputs["loss_cls"]
227        acc_batch = outputs["acc"]
228
229        # update the accumulated metrics in order to calculate the validation metrics
230        self.loss_cls_val.update(loss_cls_batch, batch_size)
231        self.acc_val.update(acc_batch, batch_size)
232
233    def on_validation_epoch_end(
234        self,
235        trainer: Trainer,
236        pl_module: CLAlgorithm,
237    ) -> None:
238        r"""Log validation metrics to plot learning curves."""
239
240        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
241        pl_module.log(
242            f"task_{self.task_id}/learning_curve/val/loss_cls",
243            self.loss_cls_val.compute(),
244            on_epoch=True,
245            prog_bar=True,
246        )
247        pl_module.log(
248            f"task_{self.task_id}/learning_curve/val/acc",
249            self.acc_val.compute(),
250            on_epoch=True,
251            prog_bar=True,
252        )
253
254    def on_test_start(
255        self,
256        trainer: Trainer,
257        pl_module: CLAlgorithm,
258    ) -> None:
259        r"""Initialise the metrics for testing each seen task in the beginning of a task's testing."""
260
261        # set the current task_id again (double checking) from the `CLAlgorithm` object
262        self.task_id = pl_module.task_id
263
264        # initialise test metrics for current and previous tasks
265        self.loss_cls_test = {
266            f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1)
267        }
268        self.acc_test = {
269            f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1)
270        }
271
272    def on_test_batch_end(
273        self,
274        trainer: Trainer,
275        pl_module: CLAlgorithm,
276        outputs: dict[str, Any],
277        batch: Any,
278        batch_idx: int,
279        dataloader_idx: int = 0,
280    ) -> None:
281        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
282
283        **Args:**
284        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
285        - **batch** (`Any`): the validation data batch.
286        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
287        """
288
289        # get the batch size
290        batch_size = len(batch)
291
292        task_id = dataloader_idx + 1  # our `task_id` system starts from 1
293
294        # get the metrics values of the batch from the outputs
295        loss_cls_batch = outputs["loss_cls"]
296        acc_batch = outputs["acc"]
297
298        # update the accumulated metrics in order to calculate the metrics of the epoch
299        self.loss_cls_test[f"{task_id}"].update(loss_cls_batch, batch_size)
300        self.acc_test[f"{task_id}"].update(acc_batch, batch_size)
301
302    def on_test_epoch_end(
303        self,
304        trainer: Trainer,
305        pl_module: CLAlgorithm,
306    ) -> None:
307        r"""Save and plot test metrics at the end of test."""
308
309        # save (update) the test metrics to CSV files
310        save.update_test_acc_to_csv(
311            test_acc_metric=self.acc_test,
312            csv_path=self.test_acc_csv_path,
313        )
314        save.update_test_loss_cls_to_csv(
315            test_loss_cls_metric=self.loss_cls_test,
316            csv_path=self.test_loss_cls_csv_path,
317        )
318
319        # plot the test metrics
320        plot.plot_test_acc_matrix_from_csv(
321            csv_path=self.test_acc_csv_path,
322            task_id=self.task_id,
323            plot_path=self.test_acc_matrix_plot_path,
324        )
325        plot.plot_test_loss_cls_matrix_from_csv(
326            csv_path=self.test_loss_cls_csv_path,
327            task_id=self.task_id,
328            plot_path=self.test_loss_cls_matrix_plot_path,
329        )
330        plot.plot_test_ave_acc_curve_from_csv(
331            csv_path=self.test_acc_csv_path,
332            task_id=self.task_id,
333            plot_path=self.test_ave_acc_plot_path,
334        )
335        plot.plot_test_ave_loss_cls_curve_from_csv(
336            csv_path=self.test_loss_cls_csv_path,
337            task_id=self.task_id,
338            plot_path=self.test_ave_loss_cls_plot_path,
339        )
class CLMetricsCallback(Callback):
    r"""Provides all actions that are related to CL metrics, which include:

    - Defining, initialising and recording metrics.
    - Logging training and validation metrics to Lightning loggers in real time.
    - Saving test metrics to files.
    - Visualising test metrics as plots.

    Please refer to [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn what continual learning metrics mean.

    Lightning provides `self.log()` to log metrics in `LightningModule`, which our `CLAlgorithm` is based on. You can put `self.log()` here if you don't want to mess up the `CLAlgorithm` with a huge amount of logging codes.

    The callback is able to produce the following outputs:

    - CSV files for test accuracy and classification loss (lower triangular) matrix, average accuracy and classification loss. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Coloured plot for test accuracy and classification loss (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Curve plots for test average accuracy and classification loss over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.
    """

    def __init__(
        self,
        save_dir: str,
        test_acc_csv_name: str | None = None,
        test_loss_cls_csv_name: str | None = None,
        test_acc_matrix_plot_name: str | None = None,
        test_loss_cls_matrix_plot_name: str | None = None,
        test_ave_acc_plot_name: str | None = None,
        test_ave_loss_cls_plot_name: str | None = None,
    ) -> None:
        r"""Initialise the `CLMetricsCallback`.

        **Args:**
        - **save_dir** (`str`): the directory to save the test metrics files and plots. Better inside the output folder. Created if it doesn't exist.
        - **test_acc_csv_name** (`str` | `None`): file name to save test accuracy matrix and average accuracy as CSV file. If `None`, no file will be saved.
        - **test_loss_cls_csv_name** (`str` | `None`): file name to save classification loss matrix and average classification loss as CSV file. If `None`, no file will be saved.
        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved. Requires `test_acc_csv_name`, as the plot is drawn from that CSV file.
        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved. Requires `test_loss_cls_csv_name`, as the plot is drawn from that CSV file.
        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved. Requires `test_acc_csv_name`, as the plot is drawn from that CSV file.
        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved. Requires `test_loss_cls_csv_name`, as the plot is drawn from that CSV file.
        """
        Callback.__init__(self)

        os.makedirs(save_dir, exist_ok=True)

        def join_or_none(file_name: str | None) -> str | None:
            # `None` file name means "don't save this output" (documented above);
            # `os.path.join()` would raise `TypeError` on `None`, so guard here.
            return os.path.join(save_dir, file_name) if file_name is not None else None

        # paths (`None` disables the corresponding output, see `on_test_epoch_end()`)
        self.test_acc_csv_path: str | None = join_or_none(test_acc_csv_name)
        r"""Store the path to save test accuracy matrix and average accuracy CSV file. `None` means not saved."""
        self.test_loss_cls_csv_path: str | None = join_or_none(test_loss_cls_csv_name)
        r"""Store the path to save test classification loss matrix and average classification loss CSV file. `None` means not saved."""
        self.test_acc_matrix_plot_path: str | None = join_or_none(
            test_acc_matrix_plot_name
        )
        r"""Store the path to save test accuracy matrix plot. `None` means not saved."""
        self.test_loss_cls_matrix_plot_path: str | None = join_or_none(
            test_loss_cls_matrix_plot_name
        )
        r"""Store the path to save test classification loss matrix plot. `None` means not saved."""
        self.test_ave_acc_plot_path: str | None = join_or_none(test_ave_acc_plot_name)
        r"""Store the path to save test average accuracy curve plot. `None` means not saved."""
        self.test_ave_loss_cls_plot_path: str | None = join_or_none(
            test_ave_loss_cls_plot_name
        )
        r"""Store the path to save test average classification loss curve plot. `None` means not saved."""

        # training accumulated metrics
        self.acc_training_epoch: MeanMetricBatch
        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
        self.loss_cls_training_epoch: MeanMetricBatch
        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
        self.loss_training_epoch: MeanMetricBatch
        r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """

        # validation accumulated metrics
        self.acc_val: MeanMetricBatch
        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
        self.loss_cls_val: MeanMetricBatch
        r"""Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """

        # test accumulated metrics
        self.acc_test: dict[str, MeanMetricBatch]
        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
        self.loss_cls_test: dict[str, MeanMetricBatch]
        r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""

    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
        r"""Initialise training and validation metrics."""

        # set the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # initialise training metrics
        self.loss_cls_training_epoch = MeanMetricBatch()
        self.loss_training_epoch = MeanMetricBatch()
        self.acc_training_epoch = MeanMetricBatch()

        # initialise validation metrics
        self.loss_cls_val = MeanMetricBatch()
        self.acc_val = MeanMetricBatch()

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        """
        # get the batch size
        # NOTE(review): `len(batch)` is the length of the batch container (e.g. 2
        # for an `(inputs, targets)` tuple), not necessarily the number of samples.
        # Confirm this matches how `MeanMetricBatch.update()` weights batches.
        batch_size = len(batch)

        # get training metrics values of current training batch from the outputs of the `training_step()`
        loss_cls_batch = outputs["loss_cls"]
        loss_batch = outputs["loss"]
        acc_batch = outputs["acc"]

        # update accumulated training metrics to calculate training metrics of the epoch
        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
        # BUG FIX: this metric previously accumulated `loss_cls_batch`, so the
        # total-loss metric silently tracked the classification loss instead.
        self.loss_training_epoch.update(loss_batch, batch_size)
        self.acc_training_epoch.update(acc_batch, batch_size)

        # log training metrics of current training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True
        )
        pl_module.log(
            f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True
        )
        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)

        # log accumulated training metrics till this training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            prog_bar=True,
        )
        pl_module.log(
            f"task_{self.task_id}/train/loss",
            self.loss_training_epoch.compute(),
            prog_bar=True,
        )
        pl_module.log(
            f"task_{self.task_id}/train/acc",
            self.acc_training_epoch.compute(),
            prog_bar=True,
        )

    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/acc",
            self.acc_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
        self.loss_cls_training_epoch.reset()
        self.loss_training_epoch.reset()
        self.acc_training_epoch.reset()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the validation data batch.
        """

        # get the batch size (see the note in `on_train_batch_end()`)
        batch_size = len(batch)

        # get the metrics values of the batch from the outputs
        loss_cls_batch = outputs["loss_cls"]
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the validation metrics
        self.loss_cls_val.update(loss_cls_batch, batch_size)
        self.acc_val.update(acc_batch, batch_size)

    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log validation metrics to plot learning curves."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/loss_cls",
            self.loss_cls_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/acc",
            self.acc_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )

    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Initialise the metrics for testing each seen task in the beginning of a task's testing."""

        # set the current task_id again (double checking) from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # initialise test metrics for current and previous tasks
        self.loss_cls_test = {
            f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1)
        }
        self.acc_test = {
            f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1)
        }

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
        """

        # get the batch size (see the note in `on_train_batch_end()`)
        batch_size = len(batch)

        task_id = dataloader_idx + 1  # our `task_id` system starts from 1

        # get the metrics values of the batch from the outputs
        loss_cls_batch = outputs["loss_cls"]
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the metrics of the epoch
        self.loss_cls_test[f"{task_id}"].update(loss_cls_batch, batch_size)
        self.acc_test[f"{task_id}"].update(acc_batch, batch_size)

    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Save and plot test metrics at the end of test.

        Outputs whose file name was `None` at construction are skipped, as
        promised in the `__init__()` docstring. Plots are generated from the
        CSV files, so each plot additionally requires its corresponding CSV
        path to be set.
        """

        # accuracy: save (update) the CSV file, then draw the plots from it
        if self.test_acc_csv_path is not None:
            save.update_test_acc_to_csv(
                test_acc_metric=self.acc_test,
                csv_path=self.test_acc_csv_path,
            )
            if self.test_acc_matrix_plot_path is not None:
                plot.plot_test_acc_matrix_from_csv(
                    csv_path=self.test_acc_csv_path,
                    task_id=self.task_id,
                    plot_path=self.test_acc_matrix_plot_path,
                )
            if self.test_ave_acc_plot_path is not None:
                plot.plot_test_ave_acc_curve_from_csv(
                    csv_path=self.test_acc_csv_path,
                    task_id=self.task_id,
                    plot_path=self.test_ave_acc_plot_path,
                )

        # classification loss: save (update) the CSV file, then draw the plots from it
        if self.test_loss_cls_csv_path is not None:
            save.update_test_loss_cls_to_csv(
                test_loss_cls_metric=self.loss_cls_test,
                csv_path=self.test_loss_cls_csv_path,
            )
            if self.test_loss_cls_matrix_plot_path is not None:
                plot.plot_test_loss_cls_matrix_from_csv(
                    csv_path=self.test_loss_cls_csv_path,
                    task_id=self.task_id,
                    plot_path=self.test_loss_cls_matrix_plot_path,
                )
            if self.test_ave_loss_cls_plot_path is not None:
                plot.plot_test_ave_loss_cls_curve_from_csv(
                    csv_path=self.test_loss_cls_csv_path,
                    task_id=self.task_id,
                    plot_path=self.test_ave_loss_cls_plot_path,
                )

Provides all actions that are related to CL metrics, which include:

  • Defining, initialising and recording metrics.
  • Logging training and validation metrics to Lightning loggers in real time.
  • Saving test metrics to files.
  • Visualising test metrics as plots.

Please refer to the A Summary of Continual Learning Metrics to learn what continual learning metrics mean.

Lightning provides self.log() to log metrics in LightningModule, on which our CLAlgorithm is based. You can put self.log() here if you don't want to clutter the CLAlgorithm with a huge amount of logging code.

The callback is able to produce the following outputs:

  • CSV files for test accuracy and classification loss (lower triangular) matrix, average accuracy and classification loss. See here for details.
  • Coloured plot for test accuracy and classification loss (lower triangular) matrix. See here for details.
  • Curve plots for test average accuracy and classification loss over different training tasks. See here for details.
CLMetricsCallback( save_dir: str, test_acc_csv_name: str | None = None, test_loss_cls_csv_name: str | None = None, test_acc_matrix_plot_name: str | None = None, test_loss_cls_matrix_plot_name: str | None = None, test_ave_acc_plot_name: str | None = None, test_ave_loss_cls_plot_name: str | None = None)
 42    def __init__(
 43        self,
 44        save_dir: str,
 45        test_acc_csv_name: str | None = None,
 46        test_loss_cls_csv_name: str | None = None,
 47        test_acc_matrix_plot_name: str | None = None,
 48        test_loss_cls_matrix_plot_name: str | None = None,
 49        test_ave_acc_plot_name: str | None = None,
 50        test_ave_loss_cls_plot_name: str | None = None,
 51    ) -> None:
 52        r"""Initialise the `CLMetricsCallback`.
 53
 54        **Args:**
 55        - **save_dir** (`str` | `None`): the directory to save the test metrics files and plots. Better inside the output folder.
 56        - **test_acc_csv_name** (`str` | `None`): file name to save test accuracy matrix and average accuracy as CSV file. If `None`, no file will be saved.
 57        - **test_loss_cls_csv_name**(`str` | `None`): file name to save classification loss matrix and average classification loss as CSV file. If `None`, no file will be saved.
 58        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
 59        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
 60        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
 61        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
 62        """
 63        Callback.__init__(self)
 64
 65        os.makedirs(save_dir, exist_ok=True)
 66
 67        # paths
 68        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 69        r"""Store the path to save test accuracy matrix and average accuracy CSV file."""
 70        self.test_loss_cls_csv_path: str = os.path.join(
 71            save_dir, test_loss_cls_csv_name
 72        )
 73        r"""Store the path to save test classification loss and average accuracy CSV file."""
 74        self.test_acc_matrix_plot_path: str = os.path.join(
 75            save_dir, test_acc_matrix_plot_name
 76        )
 77        r"""Store the path to save test accuracy matrix plot."""
 78        self.test_loss_cls_matrix_plot_path: str = os.path.join(
 79            save_dir, test_loss_cls_matrix_plot_name
 80        )
 81        r"""Store the path to save test classification loss matrix plot."""
 82        self.test_ave_acc_plot_path: str = os.path.join(
 83            save_dir, test_ave_acc_plot_name
 84        )
 85        r"""Store the path to save test average accuracy curve plot."""
 86        self.test_ave_loss_cls_plot_path: str = os.path.join(
 87            save_dir, test_ave_loss_cls_plot_name
 88        )
 89        r"""Store the path to save test average classification loss curve plot."""
 90
 91        # training accumulated metrics
 92        self.acc_training_epoch: MeanMetricBatch
 93        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 94        self.loss_cls_training_epoch: MeanMetricBatch
 95        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 96        self.loss_training_epoch: MeanMetricBatch
 97        r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
 98
 99        # validation accumulated metrics
100        self.acc_val: MeanMetricBatch
101        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
102        self.loss_cls_val: MeanMetricBatch
103        r"""Validation classification of the model loss after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """
104
105        # test accumulated metrics
106        self.acc_test: dict[str, MeanMetricBatch]
107        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
108        self.loss_cls_test: dict[str, MeanMetricBatch]
109        r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """
110
111        self.task_id: int
112        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""

Initialise the CLMetricsCallback.

Args:

  • save_dir (str | None): the directory to save the test metrics files and plots. Better inside the output folder.
  • test_acc_csv_name (str | None): file name to save test accuracy matrix and average accuracy as CSV file. If None, no file will be saved.
  • test_loss_cls_csv_name(str | None): file name to save classification loss matrix and average classification loss as CSV file. If None, no file will be saved.
  • test_acc_matrix_plot_name (str | None): file name to save accuracy matrix plot. If None, no file will be saved.
  • test_loss_cls_matrix_plot_name (str | None): file name to save classification loss matrix plot. If None, no file will be saved.
  • test_ave_acc_plot_name (str | None): file name to save average accuracy as curve plot over different training tasks. If None, no file will be saved.
  • test_ave_loss_cls_plot_name (str | None): file name to save average classification loss as curve plot over different training tasks. If None, no file will be saved.
test_acc_csv_path: str

Store the path to save test accuracy matrix and average accuracy CSV file.

test_loss_cls_csv_path: str

Store the path to save test classification loss matrix and average classification loss CSV file.

test_acc_matrix_plot_path: str

Store the path to save test accuracy matrix plot.

test_loss_cls_matrix_plot_path: str

Store the path to save test classification loss matrix plot.

test_ave_acc_plot_path: str

Store the path to save test average accuracy curve plot.

test_ave_loss_cls_plot_path: str

Store the path to save test average classification loss curve plot.

Classification accuracy of training epoch. Accumulated and calculated from the training batches. See here for details.

loss_cls_training_epoch: clarena.utils.metrics.MeanMetricBatch

Classification loss of training epoch. Accumulated and calculated from the training batches. See here for details.

Total loss of training epoch. Accumulated and calculated from the training batches. See here for details.

Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See here for details.

Validation classification loss of the model after the training epoch. Accumulated and calculated from the validation batches. See here for details.

Test classification accuracy of the current model (self.task_id) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.

loss_cls_test: dict[str, clarena.utils.metrics.MeanMetricBatch]

Test classification loss of the current model (self.task_id) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.

task_id: int

Task ID counter indicating which task is being processed. Self updated during the task loop.

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
114    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
115        r"""Initialise training and validation metrics."""
116
117        # set the current task_id from the `CLAlgorithm` object
118        self.task_id = pl_module.task_id
119
120        # initialise training metrics
121        self.loss_cls_training_epoch = MeanMetricBatch()
122        self.loss_training_epoch = MeanMetricBatch()
123        self.acc_training_epoch = MeanMetricBatch()
124
125        # initialise validation metrics
126        self.loss_cls_val = MeanMetricBatch()
127        self.acc_val = MeanMetricBatch()

Initialise training and validation metrics.

def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
129    def on_train_batch_end(
130        self,
131        trainer: Trainer,
132        pl_module: CLAlgorithm,
133        outputs: dict[str, Any],
134        batch: Any,
135        batch_idx: int,
136    ) -> None:
137        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
138
139        **Args:**
140        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
141        - **batch** (`Any`): the training data batch.
142        """
143        # get the batch size
144        batch_size = len(batch)
145
146        # get training metrics values of current training batch from the outputs of the `training_step()`
147        loss_cls_batch = outputs["loss_cls"]
148        loss_batch = outputs["loss"]
149        acc_batch = outputs["acc"]
150
151        # update accumulated training metrics to calculate training metrics of the epoch
152        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
153        self.loss_training_epoch.update(loss_cls_batch, batch_size)
154        self.acc_training_epoch.update(acc_batch, batch_size)
155
156        # log training metrics of current training batch to Lightning loggers
157        pl_module.log(
158            f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True
159        )
160        pl_module.log(
161            f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True
162        )
163        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)
164
165        # log accumulated training metrics till this training batch to Lightning loggers
166        pl_module.log(
167            f"task_{self.task_id}/train/loss_cls",
168            self.loss_cls_training_epoch.compute(),
169            prog_bar=True,
170        )
171        pl_module.log(
172            f"task_{self.task_id}/train/loss",
173            self.loss_training_epoch.compute(),
174            prog_bar=True,
175        )
176        pl_module.log(
177            f"task_{self.task_id}/train/acc",
178            self.acc_training_epoch.compute(),
179            prog_bar=True,
180        )

Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, the returns of the training_step() method in the CLAlgorithm.
  • batch (Any): the training data batch.
def on_train_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
182    def on_train_epoch_end(
183        self,
184        trainer: Trainer,
185        pl_module: CLAlgorithm,
186    ) -> None:
187        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
188
189        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
190        pl_module.log(
191            f"task_{self.task_id}/learning_curve/train/loss_cls",
192            self.loss_cls_training_epoch.compute(),
193            on_epoch=True,
194            prog_bar=True,
195        )
196        pl_module.log(
197            f"task_{self.task_id}/learning_curve/train/acc",
198            self.acc_training_epoch.compute(),
199            on_epoch=True,
200            prog_bar=True,
201        )
202
203        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
204        self.loss_cls_training_epoch.reset()
205        self.loss_training_epoch.reset()
206        self.acc_training_epoch.reset()

Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.

def on_validation_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
208    def on_validation_batch_end(
209        self,
210        trainer: Trainer,
211        pl_module: CLAlgorithm,
212        outputs: dict[str, Any],
213        batch: Any,
214        batch_idx: int,
215    ) -> None:
216        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
217
218        **Args:**
219        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
220        - **batch** (`Any`): the validation data batch.
221        """
222
223        # get the batch size
224        batch_size = len(batch)
225
226        # get the metrics values of the batch from the outputs
227        loss_cls_batch = outputs["loss_cls"]
228        acc_batch = outputs["acc"]
229
230        # update the accumulated metrics in order to calculate the validation metrics
231        self.loss_cls_val.update(loss_cls_batch, batch_size)
232        self.acc_val.update(acc_batch, batch_size)

Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

Args:

  • outputs (dict[str, Any]): the outputs of the validation step, which is the returns of the validation_step() method in the CLAlgorithm.
  • batch (Any): the validation data batch.
def on_validation_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
234    def on_validation_epoch_end(
235        self,
236        trainer: Trainer,
237        pl_module: CLAlgorithm,
238    ) -> None:
239        r"""Log validation metrics to plot learning curves."""
240
241        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
242        pl_module.log(
243            f"task_{self.task_id}/learning_curve/val/loss_cls",
244            self.loss_cls_val.compute(),
245            on_epoch=True,
246            prog_bar=True,
247        )
248        pl_module.log(
249            f"task_{self.task_id}/learning_curve/val/acc",
250            self.acc_val.compute(),
251            on_epoch=True,
252            prog_bar=True,
253        )

Log validation metrics to plot learning curves.

def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
255    def on_test_start(
256        self,
257        trainer: Trainer,
258        pl_module: CLAlgorithm,
259    ) -> None:
260        r"""Initialise the metrics for testing each seen task in the beginning of a task's testing."""
261
262        # set the current task_id again (double checking) from the `CLAlgorithm` object
263        self.task_id = pl_module.task_id
264
265        # initialise test metrics for current and previous tasks
266        self.loss_cls_test = {
267            f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1)
268        }
269        self.acc_test = {
270            f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1)
271        }

Initialise the metrics for testing each seen task in the beginning of a task's testing.

def on_test_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
273    def on_test_batch_end(
274        self,
275        trainer: Trainer,
276        pl_module: CLAlgorithm,
277        outputs: dict[str, Any],
278        batch: Any,
279        batch_idx: int,
280        dataloader_idx: int = 0,
281    ) -> None:
282        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
283
284        **Args:**
285        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
286        - **batch** (`Any`): the validation data batch.
287        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
288        """
289
290        # get the batch size
291        batch_size = len(batch)
292
293        task_id = dataloader_idx + 1  # our `task_id` system starts from 1
294
295        # get the metrics values of the batch from the outputs
296        loss_cls_batch = outputs["loss_cls"]
297        acc_batch = outputs["acc"]
298
299        # update the accumulated metrics in order to calculate the metrics of the epoch
300        self.loss_cls_test[f"{task_id}"].update(loss_cls_batch, batch_size)
301        self.acc_test[f"{task_id}"].update(acc_batch, batch_size)

Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, which is the returns of the test_step() method in the CLAlgorithm.
  • batch (Any): the validation data batch.
  • dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a RuntimeError.
def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.cl_algorithms.CLAlgorithm) -> None:
303    def on_test_epoch_end(
304        self,
305        trainer: Trainer,
306        pl_module: CLAlgorithm,
307    ) -> None:
308        r"""Save and plot test metrics at the end of test."""
309
310        # save (update) the test metrics to CSV files
311        save.update_test_acc_to_csv(
312            test_acc_metric=self.acc_test,
313            csv_path=self.test_acc_csv_path,
314        )
315        save.update_test_loss_cls_to_csv(
316            test_loss_cls_metric=self.loss_cls_test,
317            csv_path=self.test_loss_cls_csv_path,
318        )
319
320        # plot the test metrics
321        plot.plot_test_acc_matrix_from_csv(
322            csv_path=self.test_acc_csv_path,
323            task_id=self.task_id,
324            plot_path=self.test_acc_matrix_plot_path,
325        )
326        plot.plot_test_loss_cls_matrix_from_csv(
327            csv_path=self.test_loss_cls_csv_path,
328            task_id=self.task_id,
329            plot_path=self.test_loss_cls_matrix_plot_path,
330        )
331        plot.plot_test_ave_acc_curve_from_csv(
332            csv_path=self.test_acc_csv_path,
333            task_id=self.task_id,
334            plot_path=self.test_ave_acc_plot_path,
335        )
336        plot.plot_test_ave_loss_cls_curve_from_csv(
337            csv_path=self.test_loss_cls_csv_path,
338            task_id=self.task_id,
339            plot_path=self.test_ave_loss_cls_plot_path,
340        )

Save and plot test metrics at the end of test.