r"""
The submodule in `metrics` for `CLAccuracy`.
"""

__all__ = ["CLAccuracy"]

import csv
import logging
import os
from typing import Any

import pandas as pd
from lightning import Trainer
from lightning.pytorch.utilities import rank_zero_only
from matplotlib import pyplot as plt
from torchmetrics import MeanMetric

from clarena.cl_algorithms import CLAlgorithm
from clarena.metrics import MetricCallback
from clarena.utils.metrics import MeanMetricBatch

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CLAccuracy(MetricCallback):
    r"""Provides all actions that are related to CL accuracy metric, which include:

    - Defining, initializing and recording accuracy metric.
    - Logging training and validation accuracy metric to Lightning loggers in real time.
    - Saving test accuracy metric to files.
    - Visualizing test accuracy metric as plots.

    The callback is able to produce the following outputs:

    - CSV files for test accuracy (lower triangular) matrix and average accuracy. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Coloured plot for test accuracy (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Curve plots for test average accuracy over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.

    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
    """

    def __init__(
        self,
        save_dir: str,
        test_acc_csv_name: str = "acc.csv",
        test_acc_matrix_plot_name: str | None = None,
        test_ave_acc_plot_name: str | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
        - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
        """
        super().__init__(save_dir=save_dir)

        # paths
        # NOTE: the plot path attributes are only created when a file name is given;
        # on_test_epoch_end() uses hasattr() on them to decide whether to plot.
        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
        r"""The path to save test accuracy matrix and average accuracy CSV file."""
        if test_acc_matrix_plot_name:
            self.test_acc_matrix_plot_path: str = os.path.join(
                save_dir, test_acc_matrix_plot_name
            )
            r"""The path to save test accuracy matrix plot."""
        if test_ave_acc_plot_name:
            self.test_ave_acc_plot_path: str = os.path.join(
                save_dir, test_ave_acc_plot_name
            )
            r"""The path to save test average accuracy curve plot."""

        # training accumulated metrics
        self.acc_training_epoch: MeanMetricBatch
        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """

        # validation accumulated metrics
        self.acc_val: MeanMetricBatch
        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """

        # test accumulated metrics
        self.acc_test: dict[int, MeanMetricBatch]
        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
        r"""Initialize training and validation metrics."""

        # set the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize training metrics
        self.acc_training_epoch = MeanMetricBatch().to(device)

        # initialize validation metrics
        self.acc_val = MeanMetricBatch().to(device)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        """
        # get the batch size
        # NOTE(review): `len(batch)` is the length of the batch object itself; if the
        # batch is an (inputs, targets) tuple this is 2, not the number of samples —
        # confirm the expected batch structure against the data module.
        batch_size = len(batch)

        # get training metrics values of current training batch from the outputs of the `training_step()`
        acc_batch = outputs["acc"]

        # update accumulated training metrics to calculate training metrics of the epoch
        self.acc_training_epoch.update(acc_batch, batch_size)

        # log training metrics of current training batch to Lightning loggers
        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)

        # log accumulated training metrics till this training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/acc",
            self.acc_training_epoch.compute(),
            prog_bar=True,
        )

    @rank_zero_only
    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/acc",
            self.acc_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
        self.acc_training_epoch.reset()

    @rank_zero_only
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the validation data batch.
        """

        # get the batch size
        batch_size = len(batch)

        # get the metrics values of the batch from the outputs
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the validation metrics
        self.acc_val.update(acc_batch, batch_size)

    @rank_zero_only
    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log validation metrics to plot learning curves."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        # NOTE(review): `acc_val` is only initialized in on_fit_start() and never reset
        # here, so across multiple validation epochs of one fit this logs a running
        # mean over all epochs so far, unlike the training counterpart which resets
        # every epoch — confirm this is intended.
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/acc",
            self.acc_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )

    @rank_zero_only
    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""

        # set the current task_id again (double checking) from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize test metrics for current and previous tasks
        self.acc_test = {
            task_id: MeanMetricBatch().to(device)
            for task_id in pl_module.processed_task_ids
        }

    @rank_zero_only
    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
        """

        # get the batch size
        batch_size = len(batch)

        # map the dataloader index back to the task ID it tests
        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)

        # get the metrics values of the batch from the outputs
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the metrics of the epoch
        self.acc_test[test_task_id].update(acc_batch, batch_size)

    @rank_zero_only
    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Save and plot test metrics at the end of test."""

        # save (update) the test metrics to CSV files
        self.update_test_acc_to_csv(
            after_training_task_id=self.task_id,
            csv_path=self.test_acc_csv_path,
        )

        # plot the test metrics; the plot path attributes only exist when the
        # corresponding file names were passed to __init__()
        if hasattr(self, "test_acc_matrix_plot_path"):
            self.plot_test_acc_matrix_from_csv(
                csv_path=self.test_acc_csv_path,
                plot_path=self.test_acc_matrix_plot_path,
            )
        if hasattr(self, "test_ave_acc_plot_path"):
            self.plot_test_ave_acc_curve_from_csv(
                csv_path=self.test_acc_csv_path,
                plot_path=self.test_ave_acc_plot_path,
            )

    def update_test_acc_to_csv(
        self,
        after_training_task_id: int,
        csv_path: str,
    ) -> None:
        r"""Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.

        **Args:**
        - **after_training_task_id** (`int`): the task ID after training.
        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
        """
        # the set of columns grows as more tasks are seen, so the header must be
        # rebuilt (and rewritten below) on every call
        processed_task_ids = list(self.acc_test.keys())
        fieldnames = ["after_training_task", "average_accuracy"] + [
            f"test_on_task_{task_id}" for task_id in processed_task_ids
        ]

        new_line = {
            "after_training_task": after_training_task_id
        }  # construct the first column

        # construct the columns and calculate the average accuracy over tasks at the same time
        average_accuracy_over_tasks = MeanMetric().to(
            device=next(iter(self.acc_test.values())).device
        )
        for task_id in processed_task_ids:
            acc = self.acc_test[task_id].compute().item()
            new_line[f"test_on_task_{task_id}"] = acc
            average_accuracy_over_tasks(acc)
        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()

        # write to the csv file
        is_first = not os.path.exists(csv_path)
        if not is_first:
            with open(csv_path, "r", encoding="utf-8") as file:
                lines = file.readlines()
                del lines[0]  # drop the old header; a fresh one is written below
        # write header
        with open(csv_path, "w", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
        # write metrics
        with open(csv_path, "a", encoding="utf-8") as file:
            if not is_first:
                file.writelines(lines)  # write the previous lines
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writerow(new_line)

    def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
        """Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
        """
        data = pd.read_csv(csv_path)
        # recover the task IDs from the "test_on_task_<id>" column names
        processed_task_ids = [
            int(col.replace("test_on_task_", ""))
            for col in data.columns
            if col.startswith("test_on_task_")
        ]

        # Get all columns that start with "test_on_task_"
        test_task_cols = [
            col for col in data.columns if col.startswith("test_on_task_")
        ]
        num_tasks = len(processed_task_ids)
        num_rows = len(data)

        # Build the accuracy matrix
        acc_matrix = data[test_task_cols].values

        fig, ax = plt.subplots(
            figsize=(2 * num_tasks, 2 * num_rows)
        )  # adaptive figure size

        cax = ax.imshow(
            acc_matrix,
            interpolation="nearest",
            cmap="Greens",
            vmin=0,
            vmax=1,
            aspect="auto",
        )

        colorbar = fig.colorbar(cax)
        yticks = colorbar.ax.get_yticks()
        colorbar.ax.set_yticks(yticks)
        colorbar.ax.set_yticklabels(
            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
        )

        # Annotate each cell; only the lower triangle is annotated since the model
        # has only been tested on tasks seen up to each training point
        for r in range(num_rows):
            for c in range(r + 1):
                ax.text(
                    c,
                    r,
                    f"{acc_matrix[r, c]:.3f}",
                    ha="center",
                    va="center",
                    color="black",
                    fontsize=10 + num_tasks,
                )

        ax.set_xticks(range(num_tasks))
        ax.set_yticks(range(num_rows))
        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
        ax.set_yticklabels(
            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
        )

        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)

    def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
        """Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
        """
        data = pd.read_csv(csv_path)
        after_training_tasks = data["after_training_task"].astype(int).tolist()

        # plot the average accuracy curve over different training tasks
        fig, ax = plt.subplots(figsize=(16, 9))
        ax.plot(
            after_training_tasks,
            data["average_accuracy"],
            marker="o",
            linewidth=2,
        )
        ax.set_xlabel("After training task $t$", fontsize=16)
        ax.set_ylabel("Average Accuracy (AA)", fontsize=16)
        ax.grid(True)
        xticks = after_training_tasks
        yticks = [i * 0.05 for i in range(21)]  # fixed 0.00–1.00 grid in steps of 0.05
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_xticklabels(xticks, fontsize=16)
        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
        fig.savefig(plot_path)
        plt.close(fig)
class CLAccuracy(MetricCallback):
    r"""Provides all actions that are related to CL accuracy metric, which include:

    - Defining, initializing and recording accuracy metric.
    - Logging training and validation accuracy metric to Lightning loggers in real time.
    - Saving test accuracy metric to files.
    - Visualizing test accuracy metric as plots.

    The callback is able to produce the following outputs:

    - CSV files for test accuracy (lower triangular) matrix and average accuracy. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Coloured plot for test accuracy (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
    - Curve plots for test average accuracy over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.

    Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
    """

    def __init__(
        self,
        save_dir: str,
        test_acc_csv_name: str = "acc.csv",
        test_acc_matrix_plot_name: str | None = None,
        test_ave_acc_plot_name: str | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
        - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
        """
        super().__init__(save_dir=save_dir)

        # paths; plot path attributes are only created when a name is given, and
        # on_test_epoch_end() checks their existence via hasattr()
        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
        r"""The path to save test accuracy matrix and average accuracy CSV file."""
        if test_acc_matrix_plot_name:
            self.test_acc_matrix_plot_path: str = os.path.join(
                save_dir, test_acc_matrix_plot_name
            )
            r"""The path to save test accuracy matrix plot."""
        if test_ave_acc_plot_name:
            self.test_ave_acc_plot_path: str = os.path.join(
                save_dir, test_ave_acc_plot_name
            )
            r"""The path to save test average accuracy curve plot."""

        # training accumulated metrics
        self.acc_training_epoch: MeanMetricBatch
        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """

        # validation accumulated metrics
        self.acc_val: MeanMetricBatch
        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """

        # test accumulated metrics
        self.acc_test: dict[int, MeanMetricBatch]
        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""

    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
        r"""Initialize training and validation metrics."""

        # set the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize training metrics
        self.acc_training_epoch = MeanMetricBatch().to(device)

        # initialize validation metrics
        self.acc_val = MeanMetricBatch().to(device)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        """
        # get the batch size
        # NOTE(review): len(batch) counts elements of the batch object itself (e.g. 2
        # for an (x, y) tuple), not the number of samples — confirm batch structure.
        batch_size = len(batch)

        # get training metrics values of current training batch from the outputs of the `training_step()`
        acc_batch = outputs["acc"]

        # update accumulated training metrics to calculate training metrics of the epoch
        self.acc_training_epoch.update(acc_batch, batch_size)

        # log training metrics of current training batch to Lightning loggers
        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)

        # log accumulated training metrics till this training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/acc",
            self.acc_training_epoch.compute(),
            prog_bar=True,
        )

    @rank_zero_only
    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/acc",
            self.acc_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # reset so the next epoch starts its accumulation from scratch
        self.acc_training_epoch.reset()

    @rank_zero_only
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the validation data batch.
        """

        # get the batch size
        batch_size = len(batch)

        # get the metrics values of the batch from the outputs
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the validation metrics
        self.acc_val.update(acc_batch, batch_size)

    @rank_zero_only
    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log validation metrics to plot learning curves and reset the accumulation for the next validation epoch."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/acc",
            self.acc_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # FIX: reset after logging so each validation epoch logs its own mean rather
        # than a running mean over all validation epochs of the fit; this mirrors the
        # per-epoch reset done in on_train_epoch_end()
        self.acc_val.reset()

    @rank_zero_only
    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""

        # set the current task_id again (double checking) from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize test metrics for current and previous tasks
        self.acc_test = {
            task_id: MeanMetricBatch().to(device)
            for task_id in pl_module.processed_task_ids
        }

    @rank_zero_only
    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
        """

        # get the batch size
        batch_size = len(batch)

        # map the dataloader index back to the task ID it belongs to
        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)

        # get the metrics values of the batch from the outputs
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the metrics of the epoch
        self.acc_test[test_task_id].update(acc_batch, batch_size)

    @rank_zero_only
    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Save and plot test metrics at the end of test."""

        # save (update) the test metrics to CSV files
        self.update_test_acc_to_csv(
            after_training_task_id=self.task_id,
            csv_path=self.test_acc_csv_path,
        )

        # plot the test metrics; the plot path attributes exist only when the
        # corresponding file names were passed to __init__()
        if hasattr(self, "test_acc_matrix_plot_path"):
            self.plot_test_acc_matrix_from_csv(
                csv_path=self.test_acc_csv_path,
                plot_path=self.test_acc_matrix_plot_path,
            )
        if hasattr(self, "test_ave_acc_plot_path"):
            self.plot_test_ave_acc_curve_from_csv(
                csv_path=self.test_acc_csv_path,
                plot_path=self.test_ave_acc_plot_path,
            )

    def update_test_acc_to_csv(
        self,
        after_training_task_id: int,
        csv_path: str,
    ) -> None:
        r"""Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.

        **Args:**
        - **after_training_task_id** (`int`): the task ID after training.
        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
        """
        # the column set grows as more tasks are seen, so the header is rebuilt
        # (and rewritten below) on every call
        processed_task_ids = list(self.acc_test.keys())
        fieldnames = ["after_training_task", "average_accuracy"] + [
            f"test_on_task_{task_id}" for task_id in processed_task_ids
        ]

        new_line = {
            "after_training_task": after_training_task_id
        }  # construct the first column

        # construct the columns and calculate the average accuracy over tasks at the same time
        average_accuracy_over_tasks = MeanMetric().to(
            device=next(iter(self.acc_test.values())).device
        )
        for task_id in processed_task_ids:
            acc = self.acc_test[task_id].compute().item()
            new_line[f"test_on_task_{task_id}"] = acc
            average_accuracy_over_tasks(acc)
        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()

        # write to the csv file
        is_first = not os.path.exists(csv_path)
        if not is_first:
            with open(csv_path, "r", encoding="utf-8") as file:
                lines = file.readlines()
                del lines[0]  # drop the old header; a fresh one is written below

        # FIX: open CSV files with newline="" as required by the csv module docs
        # (prevents spurious blank lines on platforms with \r\n translation)
        # write header
        with open(csv_path, "w", encoding="utf-8", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
        # write metrics
        with open(csv_path, "a", encoding="utf-8", newline="") as file:
            if not is_first:
                file.writelines(lines)  # write the previous lines
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writerow(new_line)

    def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
        """Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
        """
        data = pd.read_csv(csv_path)
        # recover task IDs from the "test_on_task_<id>" column names
        processed_task_ids = [
            int(col.replace("test_on_task_", ""))
            for col in data.columns
            if col.startswith("test_on_task_")
        ]

        # Get all columns that start with "test_on_task_"
        test_task_cols = [
            col for col in data.columns if col.startswith("test_on_task_")
        ]
        num_tasks = len(processed_task_ids)
        num_rows = len(data)

        # Build the accuracy matrix
        acc_matrix = data[test_task_cols].values

        fig, ax = plt.subplots(
            figsize=(2 * num_tasks, 2 * num_rows)
        )  # adaptive figure size

        cax = ax.imshow(
            acc_matrix,
            interpolation="nearest",
            cmap="Greens",
            vmin=0,
            vmax=1,
            aspect="auto",
        )

        colorbar = fig.colorbar(cax)
        yticks = colorbar.ax.get_yticks()
        colorbar.ax.set_yticks(yticks)
        colorbar.ax.set_yticklabels(
            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
        )

        # Annotate each cell of the lower triangle (only tasks seen so far were tested)
        for r in range(num_rows):
            for c in range(r + 1):
                ax.text(
                    c,
                    r,
                    f"{acc_matrix[r, c]:.3f}",
                    ha="center",
                    va="center",
                    color="black",
                    fontsize=10 + num_tasks,
                )

        ax.set_xticks(range(num_tasks))
        ax.set_yticks(range(num_rows))
        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
        ax.set_yticklabels(
            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
        )

        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)

    def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
        """Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
        """
        data = pd.read_csv(csv_path)
        after_training_tasks = data["after_training_task"].astype(int).tolist()

        # plot the average accuracy curve over different training tasks
        fig, ax = plt.subplots(figsize=(16, 9))
        ax.plot(
            after_training_tasks,
            data["average_accuracy"],
            marker="o",
            linewidth=2,
        )
        ax.set_xlabel("After training task $t$", fontsize=16)
        ax.set_ylabel("Average Accuracy (AA)", fontsize=16)
        ax.grid(True)
        xticks = after_training_tasks
        yticks = [i * 0.05 for i in range(21)]  # fixed 0.00–1.00 grid, step 0.05
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_xticklabels(xticks, fontsize=16)
        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
        # FIX: tighten layout before saving, consistent with plot_test_acc_matrix_from_csv()
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)
Provides all actions that are related to CL accuracy metric, which include:
- Defining, initializing and recording accuracy metric.
- Logging training and validation accuracy metric to Lightning loggers in real time.
- Saving test accuracy metric to files.
- Visualizing test accuracy metric as plots.
The callback is able to produce the following outputs:
- CSV files for test accuracy (lower triangular) matrix and average accuracy. See here for details.
- Coloured plot for test accuracy (lower triangular) matrix. See here for details.
- Curve plots for test average accuracy over different training tasks. See here for details.
Please refer to *A Summary of Continual Learning Metrics* to learn about this metric.
44 def __init__( 45 self, 46 save_dir: str, 47 test_acc_csv_name: str = "acc.csv", 48 test_acc_matrix_plot_name: str | None = None, 49 test_ave_acc_plot_name: str | None = None, 50 ) -> None: 51 r""" 52 **Args:** 53 - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder. 54 - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file. 55 - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved. 56 - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved. 57 """ 58 super().__init__(save_dir=save_dir) 59 60 # paths 61 self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name) 62 r"""The path to save test accuracy matrix and average accuracy CSV file.""" 63 if test_acc_matrix_plot_name: 64 self.test_acc_matrix_plot_path: str = os.path.join( 65 save_dir, test_acc_matrix_plot_name 66 ) 67 r"""The path to save test accuracy matrix plot.""" 68 if test_ave_acc_plot_name: 69 self.test_ave_acc_plot_path: str = os.path.join( 70 save_dir, test_ave_acc_plot_name 71 ) 72 r"""The path to save test average accuracy curve plot.""" 73 74 # training accumulated metrics 75 self.acc_training_epoch: MeanMetricBatch 76 r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. """ 77 78 # validation accumulated metrics 79 self.acc_val: MeanMetricBatch 80 r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. 
""" 81 82 # test accumulated metrics 83 self.acc_test: dict[int, MeanMetricBatch] 84 r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """ 85 86 # task ID control 87 self.task_id: int 88 r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
Args:
- `save_dir` (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
- `test_acc_csv_name` (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
- `test_acc_matrix_plot_name` (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
- `test_ave_acc_plot_name` (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
Classification accuracy of training epoch. Accumulated and calculated from the training batches. See here for details.
Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See here for details.
Test classification accuracy of the current model (self.task_id) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See here for details.
Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to cl_dataset.num_tasks.
90 @rank_zero_only 91 def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None: 92 r"""Initialize training and validation metrics.""" 93 94 # set the current task_id from the `CLAlgorithm` object 95 self.task_id = pl_module.task_id 96 97 # get the device to put the metrics on the same device 98 device = pl_module.device 99 100 # initialize training metrics 101 self.acc_training_epoch = MeanMetricBatch().to(device) 102 103 # initialize validation metrics 104 self.acc_val = MeanMetricBatch().to(device)
Initialize training and validation metrics.
106 @rank_zero_only 107 def on_train_batch_end( 108 self, 109 trainer: Trainer, 110 pl_module: CLAlgorithm, 111 outputs: dict[str, Any], 112 batch: Any, 113 batch_idx: int, 114 ) -> None: 115 r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers. 116 117 **Args:** 118 - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`. 119 - **batch** (`Any`): the training data batch. 120 """ 121 # get the batch size 122 batch_size = len(batch) 123 124 # get training metrics values of current training batch from the outputs of the `training_step()` 125 acc_batch = outputs["acc"] 126 127 # update accumulated training metrics to calculate training metrics of the epoch 128 self.acc_training_epoch.update(acc_batch, batch_size) 129 130 # log training metrics of current training batch to Lightning loggers 131 pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True) 132 133 # log accumulated training metrics till this training batch to Lightning loggers 134 pl_module.log( 135 f"task_{self.task_id}/train/acc", 136 self.acc_training_epoch.compute(), 137 prog_bar=True, 138 )
Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
Args:
- `outputs` (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
- `batch` (`Any`): the training data batch.
140 @rank_zero_only 141 def on_train_epoch_end( 142 self, 143 trainer: Trainer, 144 pl_module: CLAlgorithm, 145 ) -> None: 146 r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.""" 147 148 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 149 pl_module.log( 150 f"task_{self.task_id}/learning_curve/train/acc", 151 self.acc_training_epoch.compute(), 152 on_epoch=True, 153 prog_bar=True, 154 ) 155 156 # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test 157 self.acc_training_epoch.reset()
Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.
159 @rank_zero_only 160 def on_validation_batch_end( 161 self, 162 trainer: Trainer, 163 pl_module: CLAlgorithm, 164 outputs: dict[str, Any], 165 batch: Any, 166 batch_idx: int, 167 ) -> None: 168 r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches. 169 170 **Args:** 171 - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`. 172 - **batch** (`Any`): the validation data batch. 173 """ 174 175 # get the batch size 176 batch_size = len(batch) 177 178 # get the metrics values of the batch from the outputs 179 acc_batch = outputs["acc"] 180 181 # update the accumulated metrics in order to calculate the validation metrics 182 self.acc_val.update(acc_batch, batch_size)
Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
Args:
- `outputs` (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`.
- `batch` (`Any`): the validation data batch.
184 @rank_zero_only 185 def on_validation_epoch_end( 186 self, 187 trainer: Trainer, 188 pl_module: CLAlgorithm, 189 ) -> None: 190 r"""Log validation metrics to plot learning curves.""" 191 192 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 193 pl_module.log( 194 f"task_{self.task_id}/learning_curve/val/acc", 195 self.acc_val.compute(), 196 on_epoch=True, 197 prog_bar=True, 198 )
Log validation metrics to plot learning curves.
200 @rank_zero_only 201 def on_test_start( 202 self, 203 trainer: Trainer, 204 pl_module: CLAlgorithm, 205 ) -> None: 206 r"""Initialize the metrics for testing each seen task in the beginning of a task's testing.""" 207 208 # set the current task_id again (double checking) from the `CLAlgorithm` object 209 self.task_id = pl_module.task_id 210 211 # get the device to put the metrics on the same device 212 device = pl_module.device 213 214 # initialize test metrics for current and previous tasks 215 self.acc_test = { 216 task_id: MeanMetricBatch().to(device) 217 for task_id in pl_module.processed_task_ids 218 }
Initialize the metrics for testing each seen task in the beginning of a task's testing.
220 @rank_zero_only 221 def on_test_batch_end( 222 self, 223 trainer: Trainer, 224 pl_module: CLAlgorithm, 225 outputs: dict[str, Any], 226 batch: Any, 227 batch_idx: int, 228 dataloader_idx: int = 0, 229 ) -> None: 230 r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches. 231 232 **Args:** 233 - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`. 234 - **batch** (`Any`): the test data batch. 235 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 236 """ 237 238 # get the batch size 239 batch_size = len(batch) 240 241 test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx) 242 243 # get the metrics values of the batch from the outputs 244 acc_batch = outputs["acc"] 245 246 # update the accumulated metrics in order to calculate the metrics of the epoch 247 self.acc_test[test_task_id].update(acc_batch, batch_size)
Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
Args:
- `outputs` (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`.
- `batch` (`Any`): the test data batch.
- `dataloader_idx` (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
249 @rank_zero_only 250 def on_test_epoch_end( 251 self, 252 trainer: Trainer, 253 pl_module: CLAlgorithm, 254 ) -> None: 255 r"""Save and plot test metrics at the end of test.""" 256 257 # save (update) the test metrics to CSV files 258 self.update_test_acc_to_csv( 259 after_training_task_id=self.task_id, 260 csv_path=self.test_acc_csv_path, 261 ) 262 263 # plot the test metrics 264 if hasattr(self, "test_acc_matrix_plot_path"): 265 self.plot_test_acc_matrix_from_csv( 266 csv_path=self.test_acc_csv_path, 267 plot_path=self.test_acc_matrix_plot_path, 268 ) 269 if hasattr(self, "test_ave_acc_plot_path"): 270 self.plot_test_ave_acc_curve_from_csv( 271 csv_path=self.test_acc_csv_path, 272 plot_path=self.test_ave_acc_plot_path, 273 )
Save and plot test metrics at the end of test.
275 def update_test_acc_to_csv( 276 self, 277 after_training_task_id: int, 278 csv_path: str, 279 ) -> None: 280 r"""Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing. 281 282 **Args:** 283 - **after_training_task_id** (`int`): the task ID after training. 284 - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'. 285 """ 286 processed_task_ids = list(self.acc_test.keys()) 287 fieldnames = ["after_training_task", "average_accuracy"] + [ 288 f"test_on_task_{task_id}" for task_id in processed_task_ids 289 ] 290 291 new_line = { 292 "after_training_task": after_training_task_id 293 } # construct the first column 294 295 # construct the columns and calculate the average accuracy over tasks at the same time 296 average_accuracy_over_tasks = MeanMetric().to( 297 device=next(iter(self.acc_test.values())).device 298 ) 299 for task_id in processed_task_ids: 300 acc = self.acc_test[task_id].compute().item() 301 new_line[f"test_on_task_{task_id}"] = acc 302 average_accuracy_over_tasks(acc) 303 new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item() 304 305 # write to the csv file 306 is_first = not os.path.exists(csv_path) 307 if not is_first: 308 with open(csv_path, "r", encoding="utf-8") as file: 309 lines = file.readlines() 310 del lines[0] 311 # write header 312 with open(csv_path, "w", encoding="utf-8") as file: 313 writer = csv.DictWriter(file, fieldnames=fieldnames) 314 writer.writeheader() 315 # write metrics 316 with open(csv_path, "a", encoding="utf-8") as file: 317 if not is_first: 318 file.writelines(lines) # write the previous lines 319 writer = csv.DictWriter(file, fieldnames=fieldnames) 320 writer.writerow(new_line)
Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing.
Args:
- `after_training_task_id` (`int`): the task ID after training.
- `csv_path` (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
322 def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None: 323 """Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory. 324 325 **Args:** 326 - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric. 327 - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'. 328 """ 329 data = pd.read_csv(csv_path) 330 processed_task_ids = [ 331 int(col.replace("test_on_task_", "")) 332 for col in data.columns 333 if col.startswith("test_on_task_") 334 ] 335 336 # Get all columns that start with "test_on_task_" 337 test_task_cols = [ 338 col for col in data.columns if col.startswith("test_on_task_") 339 ] 340 num_tasks = len(processed_task_ids) 341 num_rows = len(data) 342 343 # Build the accuracy matrix 344 acc_matrix = data[test_task_cols].values 345 346 fig, ax = plt.subplots( 347 figsize=(2 * num_tasks, 2 * num_rows) 348 ) # adaptive figure size 349 350 cax = ax.imshow( 351 acc_matrix, 352 interpolation="nearest", 353 cmap="Greens", 354 vmin=0, 355 vmax=1, 356 aspect="auto", 357 ) 358 359 colorbar = fig.colorbar(cax) 360 yticks = colorbar.ax.get_yticks() 361 colorbar.ax.set_yticks(yticks) 362 colorbar.ax.set_yticklabels( 363 [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks 364 ) 365 366 # Annotate each cell 367 for r in range(num_rows): 368 for c in range(r + 1): 369 ax.text( 370 c, 371 r, 372 f"{acc_matrix[r, c]:.3f}", 373 ha="center", 374 va="center", 375 color="black", 376 fontsize=10 + num_tasks, 377 ) 378 379 ax.set_xticks(range(num_tasks)) 380 ax.set_yticks(range(num_rows)) 381 ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks) 382 ax.set_yticklabels( 383 data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks 384 ) 385 386 ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks) 387 
ax.set_ylabel("After training task t", fontsize=10 + num_tasks) 388 fig.tight_layout() 389 fig.savefig(plot_path) 390 plt.close(fig)
Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory.
Args:
- `csv_path` (`str`): the path to the CSV file where `update_test_acc_to_csv()` saved the test accuracy metric.
- `plot_path` (`str`): the path to save the plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
392 def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None: 393 """Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory. 394 395 **Args:** 396 - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric. 397 - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'. 398 """ 399 data = pd.read_csv(csv_path) 400 after_training_tasks = data["after_training_task"].astype(int).tolist() 401 402 # plot the average accuracy curve over different training tasks 403 fig, ax = plt.subplots(figsize=(16, 9)) 404 ax.plot( 405 after_training_tasks, 406 data["average_accuracy"], 407 marker="o", 408 linewidth=2, 409 ) 410 ax.set_xlabel("After training task $t$", fontsize=16) 411 ax.set_ylabel("Average Accuracy (AA)", fontsize=16) 412 ax.grid(True) 413 xticks = after_training_tasks 414 yticks = [i * 0.05 for i in range(21)] 415 ax.set_xticks(xticks) 416 ax.set_yticks(yticks) 417 ax.set_xticklabels(xticks, fontsize=16) 418 ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16) 419 fig.savefig(plot_path) 420 plt.close(fig)
Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory.
Args:
- `csv_path` (`str`): the path to the CSV file where `update_test_acc_to_csv()` saved the test accuracy metric.
- `plot_path` (`str`): the path to save the plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.