clarena.metrics.cl_loss
The submodule in metrics for CLLoss.
r"""
The submodule in `metrics` for `CLLoss`.
"""

__all__ = ["CLLoss"]

import csv
import logging
import os
from typing import Any

import pandas as pd
from lightning import Trainer
from lightning.pytorch.utilities import rank_zero_only
from matplotlib import pyplot as plt
from torchmetrics import MeanMetric

from clarena.cl_algorithms import CLAlgorithm
from clarena.metrics import MetricCallback
from clarena.utils.metrics import MeanMetricBatch

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CLLoss(MetricCallback):
    r"""Provides all actions that are related to CL loss metrics, which include:

    - Defining, initializing and recording loss metrics.
    - Logging training and validation loss metrics to Lightning loggers in real time.
    - Saving test loss metrics to files.
    - Visualizing test loss metrics as plots.


    The callback is able to produce the following outputs:

    - CSV files for classification loss (lower triangular) matrix and average classification loss. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Coloured plot for test classification loss (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Curve plots for test average classification loss over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.

    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
    """

    def __init__(
        self,
        save_dir: str,
        test_loss_cls_csv_name: str = "loss_cls.csv",
        test_loss_cls_matrix_plot_name: str | None = None,
        test_ave_loss_cls_plot_name: str | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
        - **test_loss_cls_csv_name** (`str`): file name to save classification loss matrix and average classification loss as CSV file.
        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
        """
        super().__init__(save_dir=save_dir)

        self.test_loss_cls_csv_path: str = os.path.join(
            save_dir, test_loss_cls_csv_name
        )
        r"""The path to save test classification loss matrix and average classification loss CSV file."""
        # the two plot path attributes only exist when a file name is given;
        # `on_test_epoch_end()` checks for them with `hasattr()`
        if test_loss_cls_matrix_plot_name:
            self.test_loss_cls_matrix_plot_path: str = os.path.join(
                save_dir, test_loss_cls_matrix_plot_name
            )
            r"""The path to save test classification loss matrix plot."""
        if test_ave_loss_cls_plot_name:
            self.test_ave_loss_cls_plot_path: str = os.path.join(
                save_dir, test_ave_loss_cls_plot_name
            )
            r"""The path to save test average classification loss curve plot."""

        # training accumulated metrics
        self.loss_cls_training_epoch: MeanMetricBatch
        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
        self.loss_training_epoch: MeanMetricBatch
        r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """

        # validation accumulated metrics
        self.loss_cls_val: MeanMetricBatch
        r"""Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """

        # test accumulated metrics
        self.loss_cls_test: dict[int, MeanMetricBatch]
        r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
        r"""Initialize training and validation metrics."""

        # set the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize training metrics
        self.loss_cls_training_epoch = MeanMetricBatch().to(device)
        self.loss_training_epoch = MeanMetricBatch().to(device)

        # initialize validation metrics
        self.loss_cls_val = MeanMetricBatch().to(device)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`. The keys `"loss_cls"` and `"loss"` are read here.
        - **batch** (`Any`): the training data batch.
        """
        # get the batch size
        # NOTE(review): `len(batch)` equals the number of samples only if `batch` is a single
        # container of samples; for an `(inputs, targets)` tuple it would be 2 -- confirm against the datamodule
        batch_size = len(batch)

        # get training metrics values of current training batch from the outputs of the `training_step()`
        loss_cls_batch = outputs["loss_cls"]
        loss_batch = outputs["loss"]

        # update accumulated training metrics to calculate training metrics of the epoch
        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
        self.loss_training_epoch.update(loss_batch, batch_size)

        # log training metrics of current training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True
        )
        pl_module.log(
            f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True
        )

        # log accumulated training metrics till this training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            prog_bar=True,
        )
        pl_module.log(
            f"task_{self.task_id}/train/loss",
            self.loss_training_epoch.compute(),
            prog_bar=True,
        )

    @rank_zero_only
    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # reset the metrics of training epoch as there are more epochs to go
        self.loss_cls_training_epoch.reset()
        self.loss_training_epoch.reset()

    @rank_zero_only
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`. The key `"loss_cls"` is read here.
        - **batch** (`Any`): the validation data batch.
        """

        # get the batch size
        # NOTE(review): same caveat as in `on_train_batch_end()` -- confirm `len(batch)` is the sample count
        batch_size = len(batch)

        # get the metrics values of the batch from the outputs
        loss_cls_batch = outputs["loss_cls"]

        # update the accumulated metrics in order to calculate the validation metrics
        self.loss_cls_val.update(loss_cls_batch, batch_size)

    @rank_zero_only
    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log validation metrics to plot learning curves and reset the metrics accumulation at the end of each validation run."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/loss_cls",
            self.loss_cls_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # the metric object is created once in `on_fit_start()`, so without a reset every
        # validation run would be averaged together with all earlier runs of the same fit
        self.loss_cls_val.reset()

    @rank_zero_only
    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""

        # set the current task_id again (double checking) from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize test metrics for current and previous tasks
        self.loss_cls_test = {
            task_id: MeanMetricBatch().to(device)
            for task_id in pl_module.processed_task_ids
        }

    @rank_zero_only
    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`. The key `"loss_cls"` is read here.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the index of the test dataloader, mapped to the task ID of the seen task being tested through `pl_module.get_test_task_id_from_dataloader_idx()`. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
        """

        # get the batch size
        # NOTE(review): same caveat as in `on_train_batch_end()` -- confirm `len(batch)` is the sample count
        batch_size = len(batch)

        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)

        # get the metrics values of the batch from the outputs
        loss_cls_batch = outputs["loss_cls"]

        # update the accumulated metrics in order to calculate the metrics of the epoch
        self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size)

    @rank_zero_only
    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Save and plot test metrics at the end of test."""

        # save (update) the test metrics to CSV files
        self.update_test_loss_cls_to_csv(
            after_training_task_id=self.task_id,
            csv_path=self.test_loss_cls_csv_path,
        )

        # plot the test metrics; the path attributes only exist if a plot name was given to `__init__()`
        if hasattr(self, "test_loss_cls_matrix_plot_path"):
            self.plot_test_loss_cls_matrix_from_csv(
                csv_path=self.test_loss_cls_csv_path,
                plot_path=self.test_loss_cls_matrix_plot_path,
            )
        if hasattr(self, "test_ave_loss_cls_plot_path"):
            self.plot_test_ave_loss_cls_curve_from_csv(
                csv_path=self.test_loss_cls_csv_path,
                plot_path=self.test_ave_loss_cls_plot_path,
            )

    def update_test_loss_cls_to_csv(
        self,
        after_training_task_id: int,
        csv_path: str,
    ) -> None:
        r"""Update the test classification loss metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.

        The header is rewritten on every call because the set of `test_on_task_*` columns is taken from the current `self.loss_cls_test`; the data rows already in the file are copied back unchanged.

        **Args:**
        - **after_training_task_id** (`int`): the task ID after training.
        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
        """
        processed_task_ids = list(self.loss_cls_test.keys())
        fieldnames = ["after_training_task", "average_classification_loss"] + [
            f"test_on_task_{task_id}" for task_id in processed_task_ids
        ]

        new_line = {
            "after_training_task": after_training_task_id
        }  # construct the first column

        # write to the columns and calculate the average classification loss over tasks at the same time
        average_classification_loss_over_tasks = MeanMetric().to(
            device=next(iter(self.loss_cls_test.values())).device
        )
        for task_id in processed_task_ids:
            loss_cls = self.loss_cls_test[task_id].compute().item()
            new_line[f"test_on_task_{task_id}"] = loss_cls
            average_classification_loss_over_tasks(loss_cls)
        new_line["average_classification_loss"] = (
            average_classification_loss_over_tasks.compute().item()
        )

        # collect the data rows already in the file (everything after the old header);
        # `newline=""` keeps their line terminators untouched so the rewritten file does not mix styles
        previous_lines: list[str] = []
        if os.path.exists(csv_path):
            with open(csv_path, "r", encoding="utf-8", newline="") as file:
                previous_lines = file.readlines()[1:]  # safe on an empty file

        # rewrite the file in one pass: new header, previous rows, new row;
        # `newline=""` is what the `csv` module requires so that it controls line terminators itself
        with open(csv_path, "w", encoding="utf-8", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
            file.writelines(previous_lines)
            writer.writerow(new_line)

    def plot_test_loss_cls_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
        r"""Plot the test classification loss matrix from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where `update_test_loss_cls_to_csv()` saved the test classification loss metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls_matrix.png'.
        """
        data = pd.read_csv(csv_path)

        # get all columns that start with "test_on_task_" and the task IDs they encode
        test_task_cols = [
            col for col in data.columns if col.startswith("test_on_task_")
        ]
        processed_task_ids = [
            int(col.removeprefix("test_on_task_")) for col in test_task_cols
        ]
        num_tasks = len(processed_task_ids)
        num_rows = len(data)

        # build the loss matrix; cells a row has no value for are read as NaN
        loss_matrix = data[test_task_cols].values

        fig, ax = plt.subplots(
            figsize=(2 * num_tasks, 2 * num_rows)
        )  # adaptive figure size

        cax = ax.imshow(
            loss_matrix,
            interpolation="nearest",
            cmap="Greens",
            aspect="auto",
        )

        colorbar = fig.colorbar(cax)
        yticks = colorbar.ax.get_yticks()
        colorbar.ax.set_yticks(yticks)
        colorbar.ax.set_yticklabels(
            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
        )

        # annotate each cell that holds a value; for the usual lower triangular matrix these
        # are exactly the cells on and below the diagonal, and it never indexes out of range
        for r in range(num_rows):
            for c in range(num_tasks):
                cell = loss_matrix[r, c]
                if pd.isna(cell):
                    continue
                ax.text(
                    c,
                    r,
                    f"{cell:.3f}",
                    ha="center",
                    va="center",
                    color="black",
                    fontsize=10 + num_tasks,
                )

        ax.set_xticks(range(num_tasks))
        ax.set_yticks(range(num_rows))
        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
        ax.set_yticklabels(
            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
        )

        # labeling the axes
        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)

    def plot_test_ave_loss_cls_curve_from_csv(
        self, csv_path: str, plot_path: str
    ) -> None:
        r"""Plot the test average classification loss curve over different training tasks from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where `update_test_loss_cls_to_csv()` saved the test classification loss metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_loss_cls.png'.
        """
        data = pd.read_csv(csv_path)
        after_training_tasks = data["after_training_task"].astype(int).tolist()

        # plot the average classification loss curve over different training tasks
        fig, ax = plt.subplots(figsize=(16, 9))
        ax.plot(
            after_training_tasks,
            data["average_classification_loss"],
            marker="o",
            linewidth=2,
        )
        ax.set_xlabel("After training task $t$", fontsize=16)
        ax.set_ylabel("Average Classification Loss", fontsize=16)
        ax.grid(True)
        xticks = after_training_tasks
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks, fontsize=16)
        # a loss is not bounded to [0, 1] like an accuracy, so the y ticks are left to
        # matplotlib to follow the data range; only the label size is set
        ax.tick_params(axis="y", labelsize=16)
        fig.savefig(plot_path)
        plt.close(fig)
class CLLoss(MetricCallback):
    r"""Provides all actions that are related to CL loss metrics, which include:

    - Defining, initializing and recording loss metrics.
    - Logging training and validation loss metrics to Lightning loggers in real time.
    - Saving test loss metrics to files.
    - Visualizing test loss metrics as plots.


    The callback is able to produce the following outputs:

    - CSV files for classification loss (lower triangular) matrix and average classification loss. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Coloured plot for test classification loss (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Curve plots for test average classification loss over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.

    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
    """

    def __init__(
        self,
        save_dir: str,
        test_loss_cls_csv_name: str = "loss_cls.csv",
        test_loss_cls_matrix_plot_name: str | None = None,
        test_ave_loss_cls_plot_name: str | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
        - **test_loss_cls_csv_name** (`str`): file name to save classification loss matrix and average classification loss as CSV file.
        - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
        - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
        """
        super().__init__(save_dir=save_dir)

        self.test_loss_cls_csv_path: str = os.path.join(
            save_dir, test_loss_cls_csv_name
        )
        r"""The path to save test classification loss matrix and average classification loss CSV file."""
        # the two plot path attributes only exist when a file name is given;
        # `on_test_epoch_end()` checks for them with `hasattr()`
        if test_loss_cls_matrix_plot_name:
            self.test_loss_cls_matrix_plot_path: str = os.path.join(
                save_dir, test_loss_cls_matrix_plot_name
            )
            r"""The path to save test classification loss matrix plot."""
        if test_ave_loss_cls_plot_name:
            self.test_ave_loss_cls_plot_path: str = os.path.join(
                save_dir, test_ave_loss_cls_plot_name
            )
            r"""The path to save test average classification loss curve plot."""

        # training accumulated metrics
        self.loss_cls_training_epoch: MeanMetricBatch
        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
        self.loss_training_epoch: MeanMetricBatch
        r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """

        # validation accumulated metrics
        self.loss_cls_val: MeanMetricBatch
        r"""Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """

        # test accumulated metrics
        self.loss_cls_test: dict[int, MeanMetricBatch]
        r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
        r"""Initialize training and validation metrics."""

        # set the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize training metrics
        self.loss_cls_training_epoch = MeanMetricBatch().to(device)
        self.loss_training_epoch = MeanMetricBatch().to(device)

        # initialize validation metrics
        self.loss_cls_val = MeanMetricBatch().to(device)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`. The keys `"loss_cls"` and `"loss"` are read here.
        - **batch** (`Any`): the training data batch.
        """
        # get the batch size
        # NOTE(review): `len(batch)` equals the number of samples only if `batch` is a single
        # container of samples; for an `(inputs, targets)` tuple it would be 2 -- confirm against the datamodule
        batch_size = len(batch)

        # get training metrics values of current training batch from the outputs of the `training_step()`
        loss_cls_batch = outputs["loss_cls"]
        loss_batch = outputs["loss"]

        # update accumulated training metrics to calculate training metrics of the epoch
        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
        self.loss_training_epoch.update(loss_batch, batch_size)

        # log training metrics of current training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/loss_cls_batch", loss_cls_batch, prog_bar=True
        )
        pl_module.log(
            f"task_{self.task_id}/train/loss_batch", loss_batch, prog_bar=True
        )

        # log accumulated training metrics till this training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            prog_bar=True,
        )
        pl_module.log(
            f"task_{self.task_id}/train/loss",
            self.loss_training_epoch.compute(),
            prog_bar=True,
        )

    @rank_zero_only
    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # reset the metrics of training epoch as there are more epochs to go
        self.loss_cls_training_epoch.reset()
        self.loss_training_epoch.reset()

    @rank_zero_only
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`. The key `"loss_cls"` is read here.
        - **batch** (`Any`): the validation data batch.
        """

        # get the batch size
        # NOTE(review): same caveat as in `on_train_batch_end()` -- confirm `len(batch)` is the sample count
        batch_size = len(batch)

        # get the metrics values of the batch from the outputs
        loss_cls_batch = outputs["loss_cls"]

        # update the accumulated metrics in order to calculate the validation metrics
        self.loss_cls_val.update(loss_cls_batch, batch_size)

    @rank_zero_only
    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log validation metrics to plot learning curves and reset the metrics accumulation at the end of each validation run."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/loss_cls",
            self.loss_cls_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # the metric object is created once in `on_fit_start()`, so without a reset every
        # validation run would be averaged together with all earlier runs of the same fit
        self.loss_cls_val.reset()

    @rank_zero_only
    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""

        # set the current task_id again (double checking) from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize test metrics for current and previous tasks
        self.loss_cls_test = {
            task_id: MeanMetricBatch().to(device)
            for task_id in pl_module.processed_task_ids
        }

    @rank_zero_only
    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`. The key `"loss_cls"` is read here.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the index of the test dataloader, mapped to the task ID of the seen task being tested through `pl_module.get_test_task_id_from_dataloader_idx()`. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
        """

        # get the batch size
        # NOTE(review): same caveat as in `on_train_batch_end()` -- confirm `len(batch)` is the sample count
        batch_size = len(batch)

        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)

        # get the metrics values of the batch from the outputs
        loss_cls_batch = outputs["loss_cls"]

        # update the accumulated metrics in order to calculate the metrics of the epoch
        self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size)

    @rank_zero_only
    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Save and plot test metrics at the end of test."""

        # save (update) the test metrics to CSV files
        self.update_test_loss_cls_to_csv(
            after_training_task_id=self.task_id,
            csv_path=self.test_loss_cls_csv_path,
        )

        # plot the test metrics; the path attributes only exist if a plot name was given to `__init__()`
        if hasattr(self, "test_loss_cls_matrix_plot_path"):
            self.plot_test_loss_cls_matrix_from_csv(
                csv_path=self.test_loss_cls_csv_path,
                plot_path=self.test_loss_cls_matrix_plot_path,
            )
        if hasattr(self, "test_ave_loss_cls_plot_path"):
            self.plot_test_ave_loss_cls_curve_from_csv(
                csv_path=self.test_loss_cls_csv_path,
                plot_path=self.test_ave_loss_cls_plot_path,
            )

    def update_test_loss_cls_to_csv(
        self,
        after_training_task_id: int,
        csv_path: str,
    ) -> None:
        r"""Update the test classification loss metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.

        The header is rewritten on every call because the set of `test_on_task_*` columns is taken from the current `self.loss_cls_test`; the data rows already in the file are copied back unchanged.

        **Args:**
        - **after_training_task_id** (`int`): the task ID after training.
        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
        """
        processed_task_ids = list(self.loss_cls_test.keys())
        fieldnames = ["after_training_task", "average_classification_loss"] + [
            f"test_on_task_{task_id}" for task_id in processed_task_ids
        ]

        new_line = {
            "after_training_task": after_training_task_id
        }  # construct the first column

        # write to the columns and calculate the average classification loss over tasks at the same time
        average_classification_loss_over_tasks = MeanMetric().to(
            device=next(iter(self.loss_cls_test.values())).device
        )
        for task_id in processed_task_ids:
            loss_cls = self.loss_cls_test[task_id].compute().item()
            new_line[f"test_on_task_{task_id}"] = loss_cls
            average_classification_loss_over_tasks(loss_cls)
        new_line["average_classification_loss"] = (
            average_classification_loss_over_tasks.compute().item()
        )

        # collect the data rows already in the file (everything after the old header);
        # `newline=""` keeps their line terminators untouched so the rewritten file does not mix styles
        previous_lines: list[str] = []
        if os.path.exists(csv_path):
            with open(csv_path, "r", encoding="utf-8", newline="") as file:
                previous_lines = file.readlines()[1:]  # safe on an empty file

        # rewrite the file in one pass: new header, previous rows, new row;
        # `newline=""` is what the `csv` module requires so that it controls line terminators itself
        with open(csv_path, "w", encoding="utf-8", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
            file.writelines(previous_lines)
            writer.writerow(new_line)

    def plot_test_loss_cls_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
        r"""Plot the test classification loss matrix from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where `update_test_loss_cls_to_csv()` saved the test classification loss metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls_matrix.png'.
        """
        data = pd.read_csv(csv_path)

        # get all columns that start with "test_on_task_" and the task IDs they encode
        test_task_cols = [
            col for col in data.columns if col.startswith("test_on_task_")
        ]
        processed_task_ids = [
            int(col.removeprefix("test_on_task_")) for col in test_task_cols
        ]
        num_tasks = len(processed_task_ids)
        num_rows = len(data)

        # build the loss matrix; cells a row has no value for are read as NaN
        loss_matrix = data[test_task_cols].values

        fig, ax = plt.subplots(
            figsize=(2 * num_tasks, 2 * num_rows)
        )  # adaptive figure size

        cax = ax.imshow(
            loss_matrix,
            interpolation="nearest",
            cmap="Greens",
            aspect="auto",
        )

        colorbar = fig.colorbar(cax)
        yticks = colorbar.ax.get_yticks()
        colorbar.ax.set_yticks(yticks)
        colorbar.ax.set_yticklabels(
            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
        )

        # annotate each cell that holds a value; for the usual lower triangular matrix these
        # are exactly the cells on and below the diagonal, and it never indexes out of range
        for r in range(num_rows):
            for c in range(num_tasks):
                cell = loss_matrix[r, c]
                if pd.isna(cell):
                    continue
                ax.text(
                    c,
                    r,
                    f"{cell:.3f}",
                    ha="center",
                    va="center",
                    color="black",
                    fontsize=10 + num_tasks,
                )

        ax.set_xticks(range(num_tasks))
        ax.set_yticks(range(num_rows))
        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
        ax.set_yticklabels(
            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
        )

        # labeling the axes
        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)

    def plot_test_ave_loss_cls_curve_from_csv(
        self, csv_path: str, plot_path: str
    ) -> None:
        r"""Plot the test average classification loss curve over different training tasks from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where `update_test_loss_cls_to_csv()` saved the test classification loss metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_loss_cls.png'.
        """
        data = pd.read_csv(csv_path)
        after_training_tasks = data["after_training_task"].astype(int).tolist()

        # plot the average classification loss curve over different training tasks
        fig, ax = plt.subplots(figsize=(16, 9))
        ax.plot(
            after_training_tasks,
            data["average_classification_loss"],
            marker="o",
            linewidth=2,
        )
        ax.set_xlabel("After training task $t$", fontsize=16)
        ax.set_ylabel("Average Classification Loss", fontsize=16)
        ax.grid(True)
        xticks = after_training_tasks
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks, fontsize=16)
        # a loss is not bounded to [0, 1] like an accuracy, so the y ticks are left to
        # matplotlib to follow the data range; only the label size is set
        ax.tick_params(axis="y", labelsize=16)
        fig.savefig(plot_path)
        plt.close(fig)
Provides all actions that are related to CL loss metrics, which include:
- Defining, initializing and recording loss metrics.
- Logging training and validation loss metrics to Lightning loggers in real time.
- Saving test loss metrics to files.
- Visualizing test loss metrics as plots.
The callback is able to produce the following outputs:
- CSV files for classification loss (lower triangular) matrix and average classification loss. See here for details.
- Coloured plot for test classification loss (lower triangular) matrix. See here for details.
- Curve plots for test average classification loss over different training tasks. See here for details.
Please refer to the A Summary of Continual Learning Metrics to learn about this metric.
def __init__(
    self,
    save_dir: str,
    test_loss_cls_csv_name: str = "loss_cls.csv",
    test_loss_cls_matrix_plot_name: str | None = None,
    test_ave_loss_cls_plot_name: str | None = None,
) -> None:
    r"""Set up the output paths of the CL loss metric and declare the metric attributes.

    **Args:**
    - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
    - **test_loss_cls_csv_name** (`str`): file name to save classification loss matrix and average classification loss as CSV file.
    - **test_loss_cls_matrix_plot_name** (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
    - **test_ave_loss_cls_plot_name** (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
    """
    super().__init__(save_dir=save_dir)

    self.test_loss_cls_csv_path: str = os.path.join(
        save_dir, test_loss_cls_csv_name
    )
    r"""The path to save test classification loss matrix and average classification loss CSV file."""
    # The two plot path attributes below are only created when a non-empty file name
    # is given; `on_test_epoch_end()` checks for them with `hasattr()` to decide
    # whether the corresponding plot is produced.
    if test_loss_cls_matrix_plot_name:
        self.test_loss_cls_matrix_plot_path: str = os.path.join(
            save_dir, test_loss_cls_matrix_plot_name
        )
        r"""The path to save test classification loss matrix plot."""
    if test_ave_loss_cls_plot_name:
        self.test_ave_loss_cls_plot_path: str = os.path.join(
            save_dir, test_ave_loss_cls_plot_name
        )
        r"""The path to save test average classification loss curve plot."""

    # The annotations from here on only declare attributes; no value is assigned in
    # `__init__`. The training and validation metrics and `task_id` are assigned in
    # `on_fit_start()`, the test metrics in `on_test_start()`.

    # training accumulated metrics
    self.loss_cls_training_epoch: MeanMetricBatch
    r"""Classification loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """
    self.loss_training_epoch: MeanMetricBatch
    r"""Total loss of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """

    # validation accumulated metrics
    self.loss_cls_val: MeanMetricBatch
    r"""Validation classification of the model loss after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """

    # test accumulated metrics
    self.loss_cls_test: dict[int, MeanMetricBatch]
    r"""Test classification loss of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """

    # task ID control
    self.task_id: int
    r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
Args:
- save_dir (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
- test_loss_cls_csv_name (`str`): file name to save classification loss matrix and average classification loss as CSV file.
- test_loss_cls_matrix_plot_name (`str` | `None`): file name to save classification loss matrix plot. If `None`, no file will be saved.
- test_ave_loss_cls_plot_name (`str` | `None`): file name to save average classification loss as curve plot over different training tasks. If `None`, no file will be saved.
The path to save test classification loss matrix and average classification loss CSV file.
Classification loss of training epoch. Accumulated and calculated from the training batches. See here for details.
Total loss of training epoch. Accumulated and calculated from the training batches. See here for details.
Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. See here for details.
Test classification loss of the current model (self.task_id) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.
Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to cl_dataset.num_tasks.
@rank_zero_only
def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
    r"""Create fresh training and validation loss accumulators for the task about to be fitted."""

    # mirror the task ID held by the `CLAlgorithm` object
    self.task_id = pl_module.task_id

    # the accumulators have to live on the same device as the module
    target_device = pl_module.device

    def new_accumulator() -> MeanMetricBatch:
        return MeanMetricBatch().to(target_device)

    # training accumulators (classification loss and total loss)
    self.loss_cls_training_epoch = new_accumulator()
    self.loss_training_epoch = new_accumulator()

    # validation accumulator (classification loss only)
    self.loss_cls_val = new_accumulator()
Initialize training and validation metrics.
@rank_zero_only
def on_train_batch_end(
    self,
    trainer: Trainer,
    pl_module: CLAlgorithm,
    outputs: dict[str, Any],
    batch: Any,
    batch_idx: int,
) -> None:
    r"""Accumulate the losses of a training batch and log both the per-batch values and the running epoch values to Lightning loggers.

    **Args:**
    - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`. The keys `"loss_cls"` and `"loss"` are read.
    - **batch** (`Any`): the training data batch.
    """
    # weight used for the running mean
    # NOTE(review): `len(batch)` is the sample count only if `batch` is a single
    # tensor / sequence of samples; for an `(inputs, targets)` tuple it would be 2 — confirm.
    num_samples = len(batch)

    # losses of this batch, as returned by `training_step()`
    loss_cls_batch = outputs["loss_cls"]
    loss_batch = outputs["loss"]

    # fold the batch into the epoch accumulators
    self.loss_cls_training_epoch.update(loss_cls_batch, num_samples)
    self.loss_training_epoch.update(loss_batch, num_samples)

    log_prefix = f"task_{self.task_id}/train"

    # per-batch values
    for key, value in (
        ("loss_cls_batch", loss_cls_batch),
        ("loss_batch", loss_batch),
    ):
        pl_module.log(f"{log_prefix}/{key}", value, prog_bar=True)

    # running values accumulated up to and including this batch
    for key, accumulator in (
        ("loss_cls", self.loss_cls_training_epoch),
        ("loss", self.loss_training_epoch),
    ):
        pl_module.log(f"{log_prefix}/{key}", accumulator.compute(), prog_bar=True)
Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
Args:
- outputs (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
- batch (`Any`): the training data batch.
@rank_zero_only
def on_train_epoch_end(
    self,
    trainer: Trainer,
    pl_module: CLAlgorithm,
) -> None:
    r"""Log the epoch-level training classification loss for the learning curve, then clear the training accumulators for the next epoch."""

    # epoch value of the classification loss, used for plotting learning curves
    epoch_loss_cls = self.loss_cls_training_epoch.compute()
    pl_module.log(
        f"task_{self.task_id}/learning_curve/train/loss_cls",
        epoch_loss_cls,
        on_epoch=True,
        prog_bar=True,
    )

    # training runs for several epochs, so both accumulators start over each epoch
    for accumulator in (self.loss_cls_training_epoch, self.loss_training_epoch):
        accumulator.reset()
Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.
@rank_zero_only
def on_validation_batch_end(
    self,
    trainer: Trainer,
    pl_module: CLAlgorithm,
    outputs: dict[str, Any],
    batch: Any,
    batch_idx: int,
) -> None:
    r"""Fold the classification loss of a validation batch into the validation accumulator. Nothing is logged per validation batch.

    **Args:**
    - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`. The key `"loss_cls"` is read.
    - **batch** (`Any`): the validation data batch.
    """
    # weight of this batch in the running mean
    num_samples = len(batch)

    # accumulate the batch classification loss
    self.loss_cls_val.update(outputs["loss_cls"], num_samples)
Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
Args:
- outputs (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
- batch (`Any`): the validation data batch.
@rank_zero_only
def on_validation_epoch_end(
    self,
    trainer: Trainer,
    pl_module: CLAlgorithm,
) -> None:
    r"""Log the validation classification loss for the learning curve and reset the validation accumulator.

    The accumulator is created once per task in `on_fit_start()`, while validation can run many times during one fit. It is therefore reset here after logging, so that each logged value covers only the validation run that just finished instead of averaging over all earlier runs.
    """

    # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
    pl_module.log(
        f"task_{self.task_id}/learning_curve/val/loss_cls",
        self.loss_cls_val.compute(),
        on_epoch=True,
        prog_bar=True,
    )

    # start the next validation run from an empty accumulator; the value logged above
    # was already computed, so it is not affected by the reset
    self.loss_cls_val.reset()
Log validation metrics to plot learning curves.
@rank_zero_only
def on_test_start(
    self,
    trainer: Trainer,
    pl_module: CLAlgorithm,
) -> None:
    r"""Create one empty test loss accumulator per processed task at the start of testing."""

    # refresh the task ID from the `CLAlgorithm` object (double checking)
    self.task_id = pl_module.task_id

    # the accumulators have to live on the same device as the module
    target_device = pl_module.device

    # one accumulator for the current task and each previous task
    test_accumulators: dict[int, MeanMetricBatch] = {}
    for seen_task_id in pl_module.processed_task_ids:
        test_accumulators[seen_task_id] = MeanMetricBatch().to(target_device)
    self.loss_cls_test = test_accumulators
Initialize the metrics for testing each seen task in the beginning of a task's testing.
@rank_zero_only
def on_test_batch_end(
    self,
    trainer: Trainer,
    pl_module: CLAlgorithm,
    outputs: dict[str, Any],
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    r"""Fold the classification loss of a test batch into the accumulator of the task being tested. Nothing is logged per test batch.

    **Args:**
    - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`. The key `"loss_cls"` is read.
    - **batch** (`Any`): the test data batch.
    - **dataloader_idx** (`int`): the index of the test dataloader, which is translated to a task ID by `pl_module.get_test_task_id_from_dataloader_idx()`. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
    """
    # weight of this batch in the running mean
    num_samples = len(batch)

    # which seen task this dataloader belongs to
    tested_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)

    # accumulate the batch classification loss for that task
    batch_loss_cls = outputs["loss_cls"]
    accumulator = self.loss_cls_test[tested_task_id]
    accumulator.update(batch_loss_cls, num_samples)
Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
Args:
- outputs (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
- batch (`Any`): the test data batch.
- dataloader_idx (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
@rank_zero_only
def on_test_epoch_end(
    self,
    trainer: Trainer,
    pl_module: CLAlgorithm,
) -> None:
    r"""Write the test losses to the CSV file and draw the plots that were configured, at the end of test."""

    csv_path = self.test_loss_cls_csv_path

    # save (update) the test metrics to the CSV file
    self.update_test_loss_cls_to_csv(
        after_training_task_id=self.task_id,
        csv_path=csv_path,
    )

    # a plot is drawn only if its path attribute was created in `__init__()`
    optional_plots = (
        ("test_loss_cls_matrix_plot_path", self.plot_test_loss_cls_matrix_from_csv),
        ("test_ave_loss_cls_plot_path", self.plot_test_ave_loss_cls_curve_from_csv),
    )
    for path_attr, draw_plot in optional_plots:
        if hasattr(self, path_attr):
            draw_plot(csv_path=csv_path, plot_path=getattr(self, path_attr))
Save and plot test metrics at the end of test.
def update_test_loss_cls_to_csv(
    self,
    after_training_task_id: int,
    csv_path: str,
) -> None:
    r"""Append the test classification loss of seen tasks as the last row of a CSV file. A new file will be created if not existing.

    The header is rewritten on every call because the set of `test_on_task_*` columns grows as more tasks are seen; rows written earlier are kept as they are (and are therefore shorter than the header).

    **Args:**
    - **after_training_task_id** (`int`): the task ID after training.
    - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
    """
    processed_task_ids = list(self.loss_cls_test.keys())
    task_columns = [f"test_on_task_{task_id}" for task_id in processed_task_ids]
    fieldnames = ["after_training_task", "average_classification_loss"] + task_columns

    # per-task test losses as plain Python floats
    losses = [
        self.loss_cls_test[task_id].compute().item()
        for task_id in processed_task_ids
    ]
    # unweighted average over tasks; NaN when there is no task to average
    average_loss = sum(losses) / len(losses) if losses else float("nan")
    new_row = [after_training_task_id, average_loss] + losses

    # collect the rows written by earlier calls, dropping the old header and blank lines;
    # an existing but empty file simply yields no previous rows
    previous_rows: list[list[str]] = []
    if os.path.exists(csv_path):
        with open(csv_path, "r", newline="", encoding="utf-8") as file:
            existing_rows = [row for row in csv.reader(file) if row]
        previous_rows = existing_rows[1:]

    # rewrite the whole file in one pass: new header, previous rows, new row.
    # `newline=""` is required by the `csv` module to avoid extra blank lines on Windows
    with open(csv_path, "w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(fieldnames)
        writer.writerows(previous_rows)
        writer.writerow(new_row)
Update the test classification loss metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.
Args:
- after_training_task_id (`int`): the task ID after training.
- csv_path (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
def plot_test_loss_cls_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
    r"""Plot the test classification loss matrix from saved CSV file and save the plot to the designated directory.

    Rows of the matrix are the CSV rows (one per `after_training_task`), columns are the `test_on_task_*` columns. Every cell that holds a value is annotated with it; cells that are empty in the CSV (tasks not yet seen when the row was written) are left blank.

    **Args:**
    - **csv_path** (`str`): the path to the CSV file where `update_test_loss_cls_to_csv()` saved the test classification loss metric.
    - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls_matrix.png'.
    """
    data = pd.read_csv(csv_path)

    # columns holding per-task test losses, and the task IDs encoded in their names
    test_task_cols = [
        col for col in data.columns if col.startswith("test_on_task_")
    ]
    processed_task_ids = [
        int(col.replace("test_on_task_", "")) for col in test_task_cols
    ]
    num_tasks = len(test_task_cols)
    num_rows = len(data)

    # build the loss matrix
    loss_matrix = data[test_task_cols].values

    fig, ax = plt.subplots(
        figsize=(2 * num_tasks, 2 * num_rows)
    )  # adaptive figure size

    cax = ax.imshow(
        loss_matrix,
        interpolation="nearest",
        cmap="Greens",
        aspect="auto",
    )

    # font size grows with the number of tasks, like the figure does
    fontsize = 10 + num_tasks

    colorbar = fig.colorbar(cax)
    yticks = colorbar.ax.get_yticks()
    colorbar.ax.set_yticks(yticks)
    colorbar.ax.set_yticklabels(
        [f"{tick:.2f}" for tick in yticks], fontsize=fontsize
    )

    # annotate each filled cell; iterating over the real matrix shape (instead of
    # assuming row r has exactly r + 1 columns) avoids an IndexError when the CSV
    # has more rows than task columns
    for r in range(num_rows):
        for c in range(num_tasks):
            value = loss_matrix[r, c]
            if pd.isna(value):
                continue
            ax.text(
                c,
                r,
                f"{value:.3f}",
                ha="center",
                va="center",
                color="black",
                fontsize=fontsize,
            )

    ax.set_xticks(range(num_tasks))
    ax.set_yticks(range(num_rows))
    ax.set_xticklabels(processed_task_ids, fontsize=fontsize)
    ax.set_yticklabels(
        data["after_training_task"].astype(int).tolist(), fontsize=fontsize
    )

    # labeling the axes
    ax.set_xlabel("Testing on task τ", fontsize=fontsize)
    ax.set_ylabel("After training task t", fontsize=fontsize)
    fig.tight_layout()
    fig.savefig(plot_path)
    plt.close(fig)
Plot the test classification loss matrix from saved CSV file and save the plot to the designated directory.
Args:
- csv_path (`str`): the path to the CSV file where `update_test_loss_cls_to_csv()` saved the test classification loss metric.
- plot_path (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls_matrix.png'.
def plot_test_ave_loss_cls_curve_from_csv(
    self, csv_path: str, plot_path: str
) -> None:
    r"""Plot the test average classification loss curve over different training tasks from saved CSV file and save the plot to the designated directory.

    The y axis carries 21 evenly spaced ticks starting at 0. They span [0, 1] when the largest average loss is at most 1, and otherwise reach up to the largest average loss rounded up to the next integer, so that losses above 1 still fall on a labelled part of the axis.

    **Args:**
    - **csv_path** (`str`): the path to the CSV file where `update_test_loss_cls_to_csv()` saved the test classification loss metric.
    - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_loss_cls.png'.
    """
    import math  # only needed here; imported locally to leave the module imports untouched

    data = pd.read_csv(csv_path)
    after_training_tasks = data["after_training_task"].astype(int).tolist()
    average_losses = data["average_classification_loss"]

    # plot the average classification loss curve over different training tasks
    fig, ax = plt.subplots(figsize=(16, 9))
    ax.plot(
        after_training_tasks,
        average_losses,
        marker="o",
        linewidth=2,
    )
    ax.set_xlabel("After training task $t$", fontsize=16)
    ax.set_ylabel("Average Classification Loss", fontsize=16)
    ax.grid(True)

    # a loss is not bounded by 1, so the tick range follows the data;
    # NaN / non-finite maxima fall back to the default [0, 1] range
    max_loss = average_losses.max()
    if math.isfinite(max_loss) and max_loss > 1.0:
        y_upper = float(math.ceil(max_loss))
    else:
        y_upper = 1.0
    tick_step = 0.05 * y_upper

    xticks = after_training_tasks
    yticks = [i * tick_step for i in range(21)]
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_xticklabels(xticks, fontsize=16)
    ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
    fig.savefig(plot_path)
    plt.close(fig)
Plot the test average classification loss curve over different training tasks from saved CSV file and save the plot to the designated directory.
Args:
- csv_path (`str`): the path to the CSV file where `update_test_loss_cls_to_csv()` saved the test classification loss metric.
- plot_path (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_loss_cls.png'.