`clarena.callbacks.cl_metrics` — the submodule in `callbacks` for `CLMetricsCallback`.
1r""" 2The submodule in `callbacks` for `CLMetricsCallback`. 3""" 4 5__all__ = ["CLMetricsCallback"] 6 7import logging 8import os 9from typing import Any 10 11from lightning import Callback, Trainer 12 13from clarena.cl_algorithms import CLAlgorithm 14from clarena.utils import MeanMetricBatch, plot, save 15 16# always get logger for built-in logging in each module 17pylogger = logging.getLogger(__name__) 18 19 20class CLMetricsCallback(Callback): 21 r"""Provides all actions that are related to CL metrics, which include: 22 23 - Defining, initialising and recording metrics. 24 - Logging training and validation metrics to Lightning loggers in real time. 25 - Saving test metrics to files. 26 - Visualising test metrics as plots. 27 28 Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn what continual learning metrics mean. 29 30 Lightning provides `self.log()` to log metrics in `LightningModule` where our `CLAlgorithm` based. You can put `self.log()` here if you don't want to mess up the `CLAlgorithm` with a huge amount of logging codes. 31 32 The callback is able to produce the following outputs: 33 34 - CSV files for test accuracy and classification loss (lower triangular) matrix, average accuracy and classification loss. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. 35 - Coloured plot for test accuracy and classification loss (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. 36 - Curve plots for test average accuracy and classification loss over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details. 
37 38 39 """ 40 41 def __init__( 42 self, 43 save_dir: str, 44 test_acc_csv_name: str | None = None, 45 test_loss_cls_csv_name: str | None = None, 46 test_acc_matrix_plot_name: str | None = None, 47 test_loss_cls_matrix_plot_name: str | None = None, 48 test_ave_acc_plot_name: str | None = None, 49 test_ave_loss_cls_plot_name: str | None = None, 50 ) -> None: 51 r"""Initialise the `CLMetricsCallback`. 52 53 **Args:** 54 - **save_dir** (`str` | `None`): the directory to save the test metrics files and plots. Better inside the output folder. 55 - **test_acc_csv_name** (`str` | `None`): file name to save test accuracy matrix and average accuracy as CSV file. If `None`, no file will be saved. 56 - **test_loss_cls_csv_name**(`str` | `None`): file name to save classification loss matrix and average classification loss as CSV file. If `None`, no file will be saved. 57 - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved. 58 - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved. 59 - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved. 60 - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved. 
61 """ 62 Callback.__init__(self) 63 64 os.makedirs(save_dir, exist_ok=True) 65 66 # paths 67 self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name) 68 r"""Store the path to save test accuracy matrix and average accuracy CSV file.""" 69 self.test_loss_cls_csv_path: str = os.path.join( 70 save_dir, test_loss_cls_csv_name 71 ) 72 r"""Store the path to save test classification loss and average accuracy CSV file.""" 73 self.test_acc_matrix_plot_path: str = os.path.join( 74 save_dir, test_acc_matrix_plot_name 75 ) 76 r"""Store the path to save test accuracy matrix plot.""" 77 self.test_loss_cls_matrix_plot_path: str = os.path.join( 78 save_dir, test_loss_cls_matrix_plot_name 79 ) 80 r"""Store the path to save test classification loss matrix plot.""" 81 self.test_ave_acc_plot_path: str = os.path.join( 82 save_dir, test_ave_acc_plot_name 83 ) 84 r"""Store the path to save test average accuracy curve plot.""" 85 self.test_ave_loss_cls_plot_path: str = os.path.join( 86 save_dir, test_ave_loss_cls_plot_name 87 ) 88 r"""Store the path to save test average classification loss curve plot.""" 89 90 # training accumulated metrics 91 self.acc_training_epoch: MeanMetricBatch 92 r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """ 93 self.loss_cls_training_epoch: MeanMetricBatch 94 r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """ 95 self.loss_training_epoch: MeanMetricBatch 96 r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. 
""" 97 98 # validation accumulated metrics 99 self.acc_val: MeanMetricBatch 100 r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """ 101 self.loss_cls_val: MeanMetricBatch 102 r"""Validation classification of the model loss after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """ 103 104 # test accumulated metrics 105 self.acc_test: dict[str, MeanMetricBatch] 106 r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """ 107 self.loss_cls_test: dict[str, MeanMetricBatch] 108 r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """ 109 110 self.task_id: int 111 r"""Task ID counter indicating which task is being processed. 
Self updated during the task loop.""" 112 113 def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None: 114 r"""Initialise training and validation metrics.""" 115 116 # set the current task_id from the `CLAlgorithm` object 117 self.task_id = pl_module.task_id 118 119 # initialise training metrics 120 self.loss_cls_training_epoch = MeanMetricBatch() 121 self.loss_training_epoch = MeanMetricBatch() 122 self.acc_training_epoch = MeanMetricBatch() 123 124 # initialise validation metrics 125 self.loss_cls_val = MeanMetricBatch() 126 self.acc_val = MeanMetricBatch() 127 128 def on_train_batch_end( 129 self, 130 trainer: Trainer, 131 pl_module: CLAlgorithm, 132 outputs: dict[str, Any], 133 batch: Any, 134 batch_idx: int, 135 ) -> None: 136 r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers. 137 138 **Args:** 139 - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`. 140 - **batch** (`Any`): the training data batch. 
141 """ 142 # get the batch size 143 batch_size = len(batch) 144 145 # get training metrics values of current training batch from the outputs of the `training_step()` 146 loss_cls_batch = outputs["loss_cls"] 147 loss_batch = outputs["loss"] 148 acc_batch = outputs["acc"] 149 150 # update accumulated training metrics to calculate training metrics of the epoch 151 self.loss_cls_training_epoch.update(loss_cls_batch, batch_size) 152 self.loss_training_epoch.update(loss_cls_batch, batch_size) 153 self.acc_training_epoch.update(acc_batch, batch_size) 154 155 # log training metrics of current training batch to Lightning loggers 156 pl_module.log( 157 f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True 158 ) 159 pl_module.log( 160 f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True 161 ) 162 pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True) 163 164 # log accumulated training metrics till this training batch to Lightning loggers 165 pl_module.log( 166 f"task_{self.task_id}/train/loss_cls", 167 self.loss_cls_training_epoch.compute(), 168 prog_bar=True, 169 ) 170 pl_module.log( 171 f"task_{self.task_id}/train/loss", 172 self.loss_training_epoch.compute(), 173 prog_bar=True, 174 ) 175 pl_module.log( 176 f"task_{self.task_id}/train/acc", 177 self.acc_training_epoch.compute(), 178 prog_bar=True, 179 ) 180 181 def on_train_epoch_end( 182 self, 183 trainer: Trainer, 184 pl_module: CLAlgorithm, 185 ) -> None: 186 r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.""" 187 188 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 189 pl_module.log( 190 f"task_{self.task_id}/learning_curve/train/loss_cls", 191 self.loss_cls_training_epoch.compute(), 192 on_epoch=True, 193 prog_bar=True, 194 ) 195 pl_module.log( 196 f"task_{self.task_id}/learning_curve/train/acc", 197 
self.acc_training_epoch.compute(), 198 on_epoch=True, 199 prog_bar=True, 200 ) 201 202 # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test 203 self.loss_cls_training_epoch.reset() 204 self.loss_training_epoch.reset() 205 self.acc_training_epoch.reset() 206 207 def on_validation_batch_end( 208 self, 209 trainer: Trainer, 210 pl_module: CLAlgorithm, 211 outputs: dict[str, Any], 212 batch: Any, 213 batch_idx: int, 214 ) -> None: 215 r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches. 216 217 **Args:** 218 - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`. 219 - **batch** (`Any`): the validation data batch. 220 """ 221 222 # get the batch size 223 batch_size = len(batch) 224 225 # get the metrics values of the batch from the outputs 226 loss_cls_batch = outputs["loss_cls"] 227 acc_batch = outputs["acc"] 228 229 # update the accumulated metrics in order to calculate the validation metrics 230 self.loss_cls_val.update(loss_cls_batch, batch_size) 231 self.acc_val.update(acc_batch, batch_size) 232 233 def on_validation_epoch_end( 234 self, 235 trainer: Trainer, 236 pl_module: CLAlgorithm, 237 ) -> None: 238 r"""Log validation metrics to plot learning curves.""" 239 240 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 241 pl_module.log( 242 f"task_{self.task_id}/learning_curve/val/loss_cls", 243 self.loss_cls_val.compute(), 244 on_epoch=True, 245 prog_bar=True, 246 ) 247 pl_module.log( 248 f"task_{self.task_id}/learning_curve/val/acc", 249 self.acc_val.compute(), 250 on_epoch=True, 251 prog_bar=True, 252 ) 253 254 def on_test_start( 255 self, 256 trainer: Trainer, 257 pl_module: CLAlgorithm, 258 ) -> None: 259 r"""Initialise the metrics for testing each seen task in the beginning 
of a task's testing.""" 260 261 # set the current task_id again (double checking) from the `CLAlgorithm` object 262 self.task_id = pl_module.task_id 263 264 # initialise test metrics for current and previous tasks 265 self.loss_cls_test = { 266 f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1) 267 } 268 self.acc_test = { 269 f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1) 270 } 271 272 def on_test_batch_end( 273 self, 274 trainer: Trainer, 275 pl_module: CLAlgorithm, 276 outputs: dict[str, Any], 277 batch: Any, 278 batch_idx: int, 279 dataloader_idx: int = 0, 280 ) -> None: 281 r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches. 282 283 **Args:** 284 - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`. 285 - **batch** (`Any`): the validation data batch. 286 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 
287 """ 288 289 # get the batch size 290 batch_size = len(batch) 291 292 task_id = dataloader_idx + 1 # our `task_id` system starts from 1 293 294 # get the metrics values of the batch from the outputs 295 loss_cls_batch = outputs["loss_cls"] 296 acc_batch = outputs["acc"] 297 298 # update the accumulated metrics in order to calculate the metrics of the epoch 299 self.loss_cls_test[f"{task_id}"].update(loss_cls_batch, batch_size) 300 self.acc_test[f"{task_id}"].update(acc_batch, batch_size) 301 302 def on_test_epoch_end( 303 self, 304 trainer: Trainer, 305 pl_module: CLAlgorithm, 306 ) -> None: 307 r"""Save and plot test metrics at the end of test.""" 308 309 # save (update) the test metrics to CSV files 310 save.update_test_acc_to_csv( 311 test_acc_metric=self.acc_test, 312 csv_path=self.test_acc_csv_path, 313 ) 314 save.update_test_loss_cls_to_csv( 315 test_loss_cls_metric=self.loss_cls_test, 316 csv_path=self.test_loss_cls_csv_path, 317 ) 318 319 # plot the test metrics 320 plot.plot_test_acc_matrix_from_csv( 321 csv_path=self.test_acc_csv_path, 322 task_id=self.task_id, 323 plot_path=self.test_acc_matrix_plot_path, 324 ) 325 plot.plot_test_loss_cls_matrix_from_csv( 326 csv_path=self.test_loss_cls_csv_path, 327 task_id=self.task_id, 328 plot_path=self.test_loss_cls_matrix_plot_path, 329 ) 330 plot.plot_test_ave_acc_curve_from_csv( 331 csv_path=self.test_acc_csv_path, 332 task_id=self.task_id, 333 plot_path=self.test_ave_acc_plot_path, 334 ) 335 plot.plot_test_ave_loss_cls_curve_from_csv( 336 csv_path=self.test_loss_cls_csv_path, 337 task_id=self.task_id, 338 plot_path=self.test_ave_loss_cls_plot_path, 339 )
class CLMetricsCallback(Callback):
    r"""Handles everything related to CL metrics:

    - Defining, initialising and recording metrics.
    - Logging training and validation metrics to Lightning loggers in real time.
    - Saving test metrics to files and visualising them as plots.

    Please refer to [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn what continual learning metrics mean.

    Lightning provides `self.log()` to log metrics in `LightningModule`, which our `CLAlgorithm` is based on. You can put `self.log()` here if you don't want to mess up the `CLAlgorithm` with a huge amount of logging code.

    Produced outputs:

    - CSV files for the test accuracy and classification loss (lower triangular) matrices and their averages. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Coloured plots for the test accuracy and classification loss (lower triangular) matrices. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Curve plots for test average accuracy and classification loss over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.
    """

    def __init__(
        self,
        save_dir: str,
        test_acc_csv_name: str | None = None,
        test_loss_cls_csv_name: str | None = None,
        test_acc_matrix_plot_name: str | None = None,
        test_loss_cls_matrix_plot_name: str | None = None,
        test_ave_acc_plot_name: str | None = None,
        test_ave_loss_cls_plot_name: str | None = None,
    ) -> None:
        r"""Initialise the `CLMetricsCallback`.

        **Args:**
        - **save_dir** (`str`): the directory to save the test metrics files and plots. Better inside the output folder.
        - **test_acc_csv_name** (`str` | `None`): file name for the test accuracy matrix / average accuracy CSV. If `None`, no file will be saved.
        - **test_loss_cls_csv_name** (`str` | `None`): file name for the classification loss matrix / average loss CSV. If `None`, no file will be saved.
        - **test_acc_matrix_plot_name** (`str` | `None`): file name for the accuracy matrix plot. If `None`, no file will be saved.
        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name for the classification loss matrix plot. If `None`, no file will be saved.
        - **test_ave_acc_plot_name** (`str` | `None`): file name for the average accuracy curve plot. If `None`, no file will be saved.
        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name for the average classification loss curve plot. If `None`, no file will be saved.
        """
        Callback.__init__(self)

        os.makedirs(save_dir, exist_ok=True)

        # Bug fix: a `None` file name is documented as "don't save", but
        # `os.path.join(save_dir, None)` raises `TypeError`. Keep `None` paths.
        self.test_acc_csv_path: str | None = (
            os.path.join(save_dir, test_acc_csv_name) if test_acc_csv_name else None
        )
        r"""Path of the test accuracy matrix / average accuracy CSV file."""
        self.test_loss_cls_csv_path: str | None = (
            os.path.join(save_dir, test_loss_cls_csv_name)
            if test_loss_cls_csv_name
            else None
        )
        r"""Path of the test classification loss / average loss CSV file."""
        self.test_acc_matrix_plot_path: str | None = (
            os.path.join(save_dir, test_acc_matrix_plot_name)
            if test_acc_matrix_plot_name
            else None
        )
        r"""Path of the test accuracy matrix plot."""
        self.test_loss_cls_matrix_plot_path: str | None = (
            os.path.join(save_dir, test_loss_cls_matrix_plot_name)
            if test_loss_cls_matrix_plot_name
            else None
        )
        r"""Path of the test classification loss matrix plot."""
        self.test_ave_acc_plot_path: str | None = (
            os.path.join(save_dir, test_ave_acc_plot_name)
            if test_ave_acc_plot_name
            else None
        )
        r"""Path of the test average accuracy curve plot."""
        self.test_ave_loss_cls_plot_path: str | None = (
            os.path.join(save_dir, test_ave_loss_cls_plot_name)
            if test_ave_loss_cls_plot_name
            else None
        )
        r"""Path of the test average classification loss curve plot."""

        # training accumulated metrics (accumulated over training batches)
        self.acc_training_epoch: MeanMetricBatch
        r"""Classification accuracy of the training epoch."""
        self.loss_cls_training_epoch: MeanMetricBatch
        r"""Classification loss of the training epoch."""
        self.loss_training_epoch: MeanMetricBatch
        r"""Total loss of the training epoch."""

        # validation accumulated metrics (accumulated over validation batches)
        self.acc_val: MeanMetricBatch
        r"""Validation classification accuracy after the training epoch."""
        self.loss_cls_val: MeanMetricBatch
        r"""Validation classification loss after the training epoch."""

        # test accumulated metrics; keys are task IDs (string type), forming
        # the last row of the lower triangular test matrix
        self.acc_test: dict[str, MeanMetricBatch]
        r"""Test classification accuracy on current and previous tasks."""
        self.loss_cls_test: dict[str, MeanMetricBatch]
        r"""Test classification loss on current and previous tasks."""

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""

    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
        r"""Initialise training and validation metrics."""
        # the current task ID comes from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # fresh accumulators for this task's training and validation
        self.loss_cls_training_epoch = MeanMetricBatch()
        self.loss_training_epoch = MeanMetricBatch()
        self.acc_training_epoch = MeanMetricBatch()
        self.loss_cls_val = MeanMetricBatch()
        self.acc_val = MeanMetricBatch()

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from the batch and log batch and epoch-accumulated metrics to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the returns of `training_step()` in the `CLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        """
        batch_size = len(batch)

        loss_cls_batch = outputs["loss_cls"]
        loss_batch = outputs["loss"]
        acc_batch = outputs["acc"]

        # Bug fix: the total-loss accumulator must receive `loss_batch`, not
        # `loss_cls_batch` (otherwise "loss" just duplicates "loss_cls").
        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
        self.loss_training_epoch.update(loss_batch, batch_size)
        self.acc_training_epoch.update(acc_batch, batch_size)

        prefix = f"task_{self.task_id}/train"

        # per-batch values
        pl_module.log(f"{prefix}/loss_cls_batch", loss_cls_batch, prog_bar=True)
        pl_module.log(f"{prefix}/loss_batch", loss_batch, prog_bar=True)
        pl_module.log(f"{prefix}/acc_batch", acc_batch, prog_bar=True)

        # accumulated values up to this batch
        pl_module.log(
            f"{prefix}/loss_cls", self.loss_cls_training_epoch.compute(), prog_bar=True
        )
        pl_module.log(
            f"{prefix}/loss", self.loss_training_epoch.compute(), prog_bar=True
        )
        pl_module.log(
            f"{prefix}/acc", self.acc_training_epoch.compute(), prog_bar=True
        )

    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log epoch metrics for learning curves and reset the epoch accumulators."""
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/acc",
            self.acc_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # reset because training runs multiple epochs, unlike validation/test
        self.loss_cls_training_epoch.reset()
        self.loss_training_epoch.reset()
        self.acc_training_epoch.reset()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Accumulate metrics from the validation batch (no per-batch logging needed).

        **Args:**
        - **outputs** (`dict[str, Any]`): the returns of `validation_step()` in the `CLAlgorithm`.
        - **batch** (`Any`): the validation data batch.
        """
        batch_size = len(batch)
        self.loss_cls_val.update(outputs["loss_cls"], batch_size)
        self.acc_val.update(outputs["acc"], batch_size)

    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log validation metrics to plot learning curves."""
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/loss_cls",
            self.loss_cls_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/acc",
            self.acc_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )

    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Initialise metrics for testing each seen task at the beginning of a task's testing."""
        # re-read the current task ID (double checking) from the `CLAlgorithm`
        self.task_id = pl_module.task_id

        # one accumulator per seen task, keyed by task ID string
        seen = range(1, self.task_id + 1)
        self.loss_cls_test = {f"{tid}": MeanMetricBatch() for tid in seen}
        self.acc_test = {f"{tid}": MeanMetricBatch() for tid in seen}

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulate metrics from the test batch (no per-batch logging needed).

        **Args:**
        - **outputs** (`dict[str, Any]`): the returns of `test_step()` in the `CLAlgorithm`.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the index of the seen task being tested. A default value of 0 is given, otherwise the LightningModule raises a `RuntimeError`.
        """
        batch_size = len(batch)
        task_id = dataloader_idx + 1  # our `task_id` system starts from 1

        self.loss_cls_test[f"{task_id}"].update(outputs["loss_cls"], batch_size)
        self.acc_test[f"{task_id}"].update(outputs["acc"], batch_size)

    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Save and plot test metrics at the end of test; outputs with a `None` path are skipped."""
        acc_csv = self.test_acc_csv_path
        loss_csv = self.test_loss_cls_csv_path

        # save (update) the test metrics to CSV files
        if acc_csv is not None:
            save.update_test_acc_to_csv(
                test_acc_metric=self.acc_test, csv_path=acc_csv
            )
        if loss_csv is not None:
            save.update_test_loss_cls_to_csv(
                test_loss_cls_metric=self.loss_cls_test, csv_path=loss_csv
            )

        # plots are generated from the CSVs, so both paths must be enabled
        if acc_csv is not None and self.test_acc_matrix_plot_path is not None:
            plot.plot_test_acc_matrix_from_csv(
                csv_path=acc_csv,
                task_id=self.task_id,
                plot_path=self.test_acc_matrix_plot_path,
            )
        if loss_csv is not None and self.test_loss_cls_matrix_plot_path is not None:
            plot.plot_test_loss_cls_matrix_from_csv(
                csv_path=loss_csv,
                task_id=self.task_id,
                plot_path=self.test_loss_cls_matrix_plot_path,
            )
        if acc_csv is not None and self.test_ave_acc_plot_path is not None:
            plot.plot_test_ave_acc_curve_from_csv(
                csv_path=acc_csv,
                task_id=self.task_id,
                plot_path=self.test_ave_acc_plot_path,
            )
        if loss_csv is not None and self.test_ave_loss_cls_plot_path is not None:
            plot.plot_test_ave_loss_cls_curve_from_csv(
                csv_path=loss_csv,
                task_id=self.task_id,
                plot_path=self.test_ave_loss_cls_plot_path,
            )
Provides all actions that are related to CL metrics, which include:
- Defining, initialising and recording metrics.
- Logging training and validation metrics to Lightning loggers in real time.
- Saving test metrics to files.
- Visualising test metrics as plots.
Please refer to the A Summary of Continual Learning Metrics to learn what continual learning metrics mean.
Lightning provides `self.log()` to log metrics in `LightningModule`, which our `CLAlgorithm` is based on. You can put `self.log()` here if you don't want to clutter the `CLAlgorithm` with a huge amount of logging code.
The callback is able to produce the following outputs:
- CSV files for test accuracy and classification loss (lower triangular) matrix, average accuracy and classification loss. See here for details.
- Coloured plot for test accuracy and classification loss (lower triangular) matrix. See here for details.
- Curve plots for test average accuracy and classification loss over different training tasks. See here for details.
def __init__(
    self,
    save_dir: str,
    test_acc_csv_name: str | None = None,
    test_loss_cls_csv_name: str | None = None,
    test_acc_matrix_plot_name: str | None = None,
    test_loss_cls_matrix_plot_name: str | None = None,
    test_ave_acc_plot_name: str | None = None,
    test_ave_loss_cls_plot_name: str | None = None,
) -> None:
    r"""Initialise the `CLMetricsCallback`.

    **Args:**
    - **save_dir** (`str`): the directory to save the test metrics files and plots. Better inside the output folder.
    - **test_acc_csv_name** (`str` | `None`): file name to save test accuracy matrix and average accuracy as CSV file. If `None`, no file will be saved.
    - **test_loss_cls_csv_name** (`str` | `None`): file name to save classification loss matrix and average classification loss as CSV file. If `None`, no file will be saved.
    - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
    - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
    - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
    - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
    """
    Callback.__init__(self)

    os.makedirs(save_dir, exist_ok=True)

    # Bug fix: `None` file names are documented as "don't save", but
    # `os.path.join(save_dir, None)` raises `TypeError`; store `None` instead.
    def _join_or_none(name: str | None) -> str | None:
        return os.path.join(save_dir, name) if name is not None else None

    # paths (`None` disables saving the corresponding file)
    self.test_acc_csv_path: str | None = _join_or_none(test_acc_csv_name)
    r"""Store the path to save test accuracy matrix and average accuracy CSV file."""
    self.test_loss_cls_csv_path: str | None = _join_or_none(test_loss_cls_csv_name)
    r"""Store the path to save test classification loss and average classification loss CSV file."""
    self.test_acc_matrix_plot_path: str | None = _join_or_none(
        test_acc_matrix_plot_name
    )
    r"""Store the path to save test accuracy matrix plot."""
    self.test_loss_cls_matrix_plot_path: str | None = _join_or_none(
        test_loss_cls_matrix_plot_name
    )
    r"""Store the path to save test classification loss matrix plot."""
    self.test_ave_acc_plot_path: str | None = _join_or_none(test_ave_acc_plot_name)
    r"""Store the path to save test average accuracy curve plot."""
    self.test_ave_loss_cls_plot_path: str | None = _join_or_none(
        test_ave_loss_cls_plot_name
    )
    r"""Store the path to save test average classification loss curve plot."""

    # training accumulated metrics (accumulated from training batches)
    self.acc_training_epoch: MeanMetricBatch
    r"""Classification accuracy of training epoch."""
    self.loss_cls_training_epoch: MeanMetricBatch
    r"""Classification loss of training epoch."""
    self.loss_training_epoch: MeanMetricBatch
    r"""Total loss of training epoch."""

    # validation accumulated metrics (accumulated from validation batches)
    self.acc_val: MeanMetricBatch
    r"""Validation classification accuracy of the model after training epoch."""
    self.loss_cls_val: MeanMetricBatch
    r"""Validation classification loss of the model after training epoch."""

    # test accumulated metrics; keys are task ID strings, forming the last
    # row of the lower triangular test matrix
    self.acc_test: dict[str, MeanMetricBatch]
    r"""Test classification accuracy of the current model on current and previous tasks."""
    self.loss_cls_test: dict[str, MeanMetricBatch]
    r"""Test classification loss of the current model on current and previous tasks."""

    self.task_id: int
    r"""Task ID counter indicating which task is being processed. Self updated during the task loop."""
Initialise the CLMetricsCallback
.
Args:
- save_dir (`str`): the directory to save the test metrics files and plots. Better inside the output folder.
- test_acc_csv_name (`str` | `None`): file name to save test accuracy matrix and average accuracy as CSV file. If `None`, no file will be saved.
- test_loss_cls_csv_name (`str` | `None`): file name to save classification loss matrix and average classification loss as CSV file. If `None`, no file will be saved.
- test_acc_matrix_plot_name (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
- test_loss_cls_matrix_plot_name (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
- test_ave_acc_plot_name (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
- test_ave_loss_cls_plot_name (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
Store the path to save test classification loss and average accuracy CSV file.
Store the path to save test average classification loss curve plot.
Classification accuracy of training epoch. Accumulated and calculated from the training batches. See here for details.
Classification loss of training epoch. Accumulated and calculated from the training batches. See here for details.
Total loss of training epoch. Accumulated and calculated from the training batches. See here for details.
Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See here for details.
Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. See here for details.
Test classification accuracy of the current model (self.task_id
) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.
Test classification loss of the current model (self.task_id
) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs (string type) and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.
Task ID counter indicating which task is being processed. Self updated during the task loop.
114 def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None: 115 r"""Initialise training and validation metrics.""" 116 117 # set the current task_id from the `CLAlgorithm` object 118 self.task_id = pl_module.task_id 119 120 # initialise training metrics 121 self.loss_cls_training_epoch = MeanMetricBatch() 122 self.loss_training_epoch = MeanMetricBatch() 123 self.acc_training_epoch = MeanMetricBatch() 124 125 # initialise validation metrics 126 self.loss_cls_val = MeanMetricBatch() 127 self.acc_val = MeanMetricBatch()
Initialise training and validation metrics.
129 def on_train_batch_end( 130 self, 131 trainer: Trainer, 132 pl_module: CLAlgorithm, 133 outputs: dict[str, Any], 134 batch: Any, 135 batch_idx: int, 136 ) -> None: 137 r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers. 138 139 **Args:** 140 - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`. 141 - **batch** (`Any`): the training data batch. 142 """ 143 # get the batch size 144 batch_size = len(batch) 145 146 # get training metrics values of current training batch from the outputs of the `training_step()` 147 loss_cls_batch = outputs["loss_cls"] 148 loss_batch = outputs["loss"] 149 acc_batch = outputs["acc"] 150 151 # update accumulated training metrics to calculate training metrics of the epoch 152 self.loss_cls_training_epoch.update(loss_cls_batch, batch_size) 153 self.loss_training_epoch.update(loss_cls_batch, batch_size) 154 self.acc_training_epoch.update(acc_batch, batch_size) 155 156 # log training metrics of current training batch to Lightning loggers 157 pl_module.log( 158 f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True 159 ) 160 pl_module.log( 161 f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True 162 ) 163 pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True) 164 165 # log accumulated training metrics till this training batch to Lightning loggers 166 pl_module.log( 167 f"task_{self.task_id}/train/loss_cls", 168 self.loss_cls_training_epoch.compute(), 169 prog_bar=True, 170 ) 171 pl_module.log( 172 f"task_{self.task_id}/train/loss", 173 self.loss_training_epoch.compute(), 174 prog_bar=True, 175 ) 176 pl_module.log( 177 f"task_{self.task_id}/train/acc", 178 self.acc_training_epoch.compute(), 179 prog_bar=True, 180 )
Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
Args:
- outputs (
dict[str, Any]
): the outputs of the training step, the returns of thetraining_step()
method in theCLAlgorithm
. - batch (
Any
): the training data batch.
182 def on_train_epoch_end( 183 self, 184 trainer: Trainer, 185 pl_module: CLAlgorithm, 186 ) -> None: 187 r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.""" 188 189 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 190 pl_module.log( 191 f"task_{self.task_id}/learning_curve/train/loss_cls", 192 self.loss_cls_training_epoch.compute(), 193 on_epoch=True, 194 prog_bar=True, 195 ) 196 pl_module.log( 197 f"task_{self.task_id}/learning_curve/train/acc", 198 self.acc_training_epoch.compute(), 199 on_epoch=True, 200 prog_bar=True, 201 ) 202 203 # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test 204 self.loss_cls_training_epoch.reset() 205 self.loss_training_epoch.reset() 206 self.acc_training_epoch.reset()
Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.
208 def on_validation_batch_end( 209 self, 210 trainer: Trainer, 211 pl_module: CLAlgorithm, 212 outputs: dict[str, Any], 213 batch: Any, 214 batch_idx: int, 215 ) -> None: 216 r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches. 217 218 **Args:** 219 - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`. 220 - **batch** (`Any`): the validation data batch. 221 """ 222 223 # get the batch size 224 batch_size = len(batch) 225 226 # get the metrics values of the batch from the outputs 227 loss_cls_batch = outputs["loss_cls"] 228 acc_batch = outputs["acc"] 229 230 # update the accumulated metrics in order to calculate the validation metrics 231 self.loss_cls_val.update(loss_cls_batch, batch_size) 232 self.acc_val.update(acc_batch, batch_size)
Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
Args:
- outputs (
dict[str, Any]
): the outputs of the validation step, which is the returns of thevalidation_step()
method in theCLAlgorithm
. - batch (
Any
): the validation data batch.
234 def on_validation_epoch_end( 235 self, 236 trainer: Trainer, 237 pl_module: CLAlgorithm, 238 ) -> None: 239 r"""Log validation metrics to plot learning curves.""" 240 241 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 242 pl_module.log( 243 f"task_{self.task_id}/learning_curve/val/loss_cls", 244 self.loss_cls_val.compute(), 245 on_epoch=True, 246 prog_bar=True, 247 ) 248 pl_module.log( 249 f"task_{self.task_id}/learning_curve/val/acc", 250 self.acc_val.compute(), 251 on_epoch=True, 252 prog_bar=True, 253 )
Log validation metrics to plot learning curves.
255 def on_test_start( 256 self, 257 trainer: Trainer, 258 pl_module: CLAlgorithm, 259 ) -> None: 260 r"""Initialise the metrics for testing each seen task in the beginning of a task's testing.""" 261 262 # set the current task_id again (double checking) from the `CLAlgorithm` object 263 self.task_id = pl_module.task_id 264 265 # initialise test metrics for current and previous tasks 266 self.loss_cls_test = { 267 f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1) 268 } 269 self.acc_test = { 270 f"{task_id}": MeanMetricBatch() for task_id in range(1, self.task_id + 1) 271 }
Initialise the metrics for testing each seen task in the beginning of a task's testing.
273 def on_test_batch_end( 274 self, 275 trainer: Trainer, 276 pl_module: CLAlgorithm, 277 outputs: dict[str, Any], 278 batch: Any, 279 batch_idx: int, 280 dataloader_idx: int = 0, 281 ) -> None: 282 r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches. 283 284 **Args:** 285 - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`. 286 - **batch** (`Any`): the validation data batch. 287 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 288 """ 289 290 # get the batch size 291 batch_size = len(batch) 292 293 task_id = dataloader_idx + 1 # our `task_id` system starts from 1 294 295 # get the metrics values of the batch from the outputs 296 loss_cls_batch = outputs["loss_cls"] 297 acc_batch = outputs["acc"] 298 299 # update the accumulated metrics in order to calculate the metrics of the epoch 300 self.loss_cls_test[f"{task_id}"].update(loss_cls_batch, batch_size) 301 self.acc_test[f"{task_id}"].update(acc_batch, batch_size)
Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
Args:
- outputs (
dict[str, Any]
): the outputs of the test step, which is the returns of thetest_step()
method in theCLAlgorithm
. - batch (
Any
): the test data batch. - dataloader_idx (
int
): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise aRuntimeError
.
303 def on_test_epoch_end( 304 self, 305 trainer: Trainer, 306 pl_module: CLAlgorithm, 307 ) -> None: 308 r"""Save and plot test metrics at the end of test.""" 309 310 # save (update) the test metrics to CSV files 311 save.update_test_acc_to_csv( 312 test_acc_metric=self.acc_test, 313 csv_path=self.test_acc_csv_path, 314 ) 315 save.update_test_loss_cls_to_csv( 316 test_loss_cls_metric=self.loss_cls_test, 317 csv_path=self.test_loss_cls_csv_path, 318 ) 319 320 # plot the test metrics 321 plot.plot_test_acc_matrix_from_csv( 322 csv_path=self.test_acc_csv_path, 323 task_id=self.task_id, 324 plot_path=self.test_acc_matrix_plot_path, 325 ) 326 plot.plot_test_loss_cls_matrix_from_csv( 327 csv_path=self.test_loss_cls_csv_path, 328 task_id=self.task_id, 329 plot_path=self.test_loss_cls_matrix_plot_path, 330 ) 331 plot.plot_test_ave_acc_curve_from_csv( 332 csv_path=self.test_acc_csv_path, 333 task_id=self.task_id, 334 plot_path=self.test_ave_acc_plot_path, 335 ) 336 plot.plot_test_ave_loss_cls_curve_from_csv( 337 csv_path=self.test_loss_cls_csv_path, 338 task_id=self.task_id, 339 plot_path=self.test_ave_loss_cls_plot_path, 340 )
Save and plot test metrics at the end of test.