clarena.metrics.cl_acc

The submodule in `metrics` for `CLAccuracy`.
1r""" 2The submodule in `metrics` for `CLAccuracy`. 3""" 4 5__all__ = ["CLAccuracy"] 6 7import csv 8import logging 9import os 10from typing import Any 11 12import pandas as pd 13from lightning import Trainer 14from lightning.pytorch.utilities import rank_zero_only 15from matplotlib import pyplot as plt 16from torchmetrics import MeanMetric 17 18from clarena.cl_algorithms import CLAlgorithm 19from clarena.metrics import MetricCallback 20from clarena.utils.metrics import MeanMetricBatch 21 22# always get logger for built-in logging in each module 23pylogger = logging.getLogger(__name__) 24 25 26class CLAccuracy(MetricCallback): 27 r"""Provides all actions that are related to CL accuracy metric, which include: 28 29 - Defining, initializing and recording accuracy metric. 30 - Logging training and validation accuracy metric to Lightning loggers in real time. 31 - Saving test accuracy metric to files. 32 - Visualizing test accuracy metric as plots. 33 34 The callback is able to produce the following outputs: 35 36 - CSV files for test accuracy (lower triangular) matrix and average accuracy. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. 37 - Coloured plot for test accuracy (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. 38 - Curve plots for test average accuracy over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details. 39 40 Please refer to the [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric. 41 """ 42 43 def __init__( 44 self, 45 save_dir: str, 46 test_acc_csv_name: str = "acc.csv", 47 test_acc_matrix_plot_name: str | None = None, 48 test_ave_acc_plot_name: str | None = None, 49 ) -> None: 50 r""" 51 **Args:** 52 - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder. 53 - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file. 54 - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved. 55 - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved. 56 """ 57 super().__init__(save_dir=save_dir) 58 59 self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name) 60 r"""The path to save test accuracy matrix and average accuracy CSV file.""" 61 if test_acc_matrix_plot_name: 62 self.test_acc_matrix_plot_path: str = os.path.join( 63 save_dir, test_acc_matrix_plot_name 64 ) 65 r"""The path to save test accuracy matrix plot.""" 66 if test_ave_acc_plot_name: 67 self.test_ave_acc_plot_path: str = os.path.join( 68 save_dir, test_ave_acc_plot_name 69 ) 70 r"""The path to save test average accuracy curve plot.""" 71 72 # training accumulated metrics 73 self.acc_training_epoch: MeanMetricBatch 74 r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details. 
""" 75 76 # validation accumulated metrics 77 self.acc_val: MeanMetricBatch 78 r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details. """ 79 80 # test accumulated metrics 81 self.acc_test: dict[int, MeanMetricBatch] 82 r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details. """ 83 84 # task ID control 85 self.task_id: int 86 r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`.""" 87 88 @rank_zero_only 89 def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None: 90 r"""Initialize training and validation metrics.""" 91 92 # set the current task_id from the `CLAlgorithm` object 93 self.task_id = pl_module.task_id 94 95 # get the device to put the metrics on the same device 96 device = pl_module.device 97 98 # initialize training metrics 99 self.acc_training_epoch = MeanMetricBatch().to(device) 100 101 # initialize validation metrics 102 self.acc_val = MeanMetricBatch().to(device) 103 104 @rank_zero_only 105 def on_train_batch_end( 106 self, 107 trainer: Trainer, 108 pl_module: CLAlgorithm, 109 outputs: dict[str, Any], 110 batch: Any, 111 batch_idx: int, 112 ) -> None: 113 r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers. 114 115 **Args:** 116 - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`. 117 - **batch** (`Any`): the training data batch. 
118 """ 119 # get the batch size 120 batch_size = len(batch) 121 122 # get training metrics values of current training batch from the outputs of the `training_step()` 123 acc_batch = outputs["acc"] 124 125 # update accumulated training metrics to calculate training metrics of the epoch 126 self.acc_training_epoch.update(acc_batch, batch_size) 127 128 # log training metrics of current training batch to Lightning loggers 129 pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True) 130 131 # log accumulated training metrics till this training batch to Lightning loggers 132 pl_module.log( 133 f"task_{self.task_id}/train/acc", 134 self.acc_training_epoch.compute(), 135 prog_bar=True, 136 ) 137 138 @rank_zero_only 139 def on_train_epoch_end( 140 self, 141 trainer: Trainer, 142 pl_module: CLAlgorithm, 143 ) -> None: 144 r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.""" 145 146 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 147 pl_module.log( 148 f"task_{self.task_id}/learning_curve/train/acc", 149 self.acc_training_epoch.compute(), 150 on_epoch=True, 151 prog_bar=True, 152 ) 153 154 # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test 155 self.acc_training_epoch.reset() 156 157 @rank_zero_only 158 def on_validation_batch_end( 159 self, 160 trainer: Trainer, 161 pl_module: CLAlgorithm, 162 outputs: dict[str, Any], 163 batch: Any, 164 batch_idx: int, 165 ) -> None: 166 r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches. 167 168 **Args:** 169 - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `CLAlgorithm`. 170 - **batch** (`Any`): the validation data batch. 
171 """ 172 173 # get the batch size 174 batch_size = len(batch) 175 176 # get the metrics values of the batch from the outputs 177 acc_batch = outputs["acc"] 178 179 # update the accumulated metrics in order to calculate the validation metrics 180 self.acc_val.update(acc_batch, batch_size) 181 182 @rank_zero_only 183 def on_validation_epoch_end( 184 self, 185 trainer: Trainer, 186 pl_module: CLAlgorithm, 187 ) -> None: 188 r"""Log validation metrics to plot learning curves.""" 189 190 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 191 pl_module.log( 192 f"task_{self.task_id}/learning_curve/val/acc", 193 self.acc_val.compute(), 194 on_epoch=True, 195 prog_bar=True, 196 ) 197 198 @rank_zero_only 199 def on_test_start( 200 self, 201 trainer: Trainer, 202 pl_module: CLAlgorithm, 203 ) -> None: 204 r"""Initialize the metrics for testing each seen task in the beginning of a task's testing.""" 205 206 # set the current task_id again (double checking) from the `CLAlgorithm` object 207 self.task_id = pl_module.task_id 208 209 # get the device to put the metrics on the same device 210 device = pl_module.device 211 212 # initialize test metrics for current and previous tasks 213 self.acc_test = { 214 task_id: MeanMetricBatch().to(device) 215 for task_id in pl_module.processed_task_ids 216 } 217 218 @rank_zero_only 219 def on_test_batch_end( 220 self, 221 trainer: Trainer, 222 pl_module: CLAlgorithm, 223 outputs: dict[str, Any], 224 batch: Any, 225 batch_idx: int, 226 dataloader_idx: int = 0, 227 ) -> None: 228 r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches. 229 230 **Args:** 231 - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `CLAlgorithm`. 232 - **batch** (`Any`): the test data batch. 233 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 234 """ 235 236 # get the batch size 237 batch_size = len(batch) 238 239 test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx) 240 241 # get the metrics values of the batch from the outputs 242 acc_batch = outputs["acc"] 243 244 # update the accumulated metrics in order to calculate the metrics of the epoch 245 self.acc_test[test_task_id].update(acc_batch, batch_size) 246 247 @rank_zero_only 248 def on_test_epoch_end( 249 self, 250 trainer: Trainer, 251 pl_module: CLAlgorithm, 252 ) -> None: 253 r"""Save and plot test metrics at the end of test.""" 254 255 # save (update) the test metrics to CSV files 256 self.update_test_acc_to_csv( 257 after_training_task_id=self.task_id, 258 csv_path=self.test_acc_csv_path, 259 ) 260 261 # plot the test metrics 262 if hasattr(self, "test_acc_matrix_plot_path"): 263 self.plot_test_acc_matrix_from_csv( 264 csv_path=self.test_acc_csv_path, 265 plot_path=self.test_acc_matrix_plot_path, 266 ) 267 if hasattr(self, "test_ave_acc_plot_path"): 268 self.plot_test_ave_acc_curve_from_csv( 269 csv_path=self.test_acc_csv_path, 270 plot_path=self.test_ave_acc_plot_path, 271 ) 272 273 def update_test_acc_to_csv( 274 self, 275 after_training_task_id: int, 276 csv_path: str, 277 ) -> None: 278 r"""Update the test accuracy metrics of seen tasks at the last line to an existing CSV file. A new file will be created if not existing. 279 280 **Args:** 281 - **after_training_task_id** (`int`): the task ID after training. 
282 - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'. 283 """ 284 processed_task_ids = list(self.acc_test.keys()) 285 fieldnames = ["after_training_task", "average_accuracy"] + [ 286 f"test_on_task_{task_id}" for task_id in processed_task_ids 287 ] 288 289 new_line = { 290 "after_training_task": after_training_task_id 291 } # construct the first column 292 293 # construct the columns and calculate the average accuracy over tasks at the same time 294 average_accuracy_over_tasks = MeanMetric().to( 295 device=next(iter(self.acc_test.values())).device 296 ) 297 for task_id in processed_task_ids: 298 acc = self.acc_test[task_id].compute().item() 299 new_line[f"test_on_task_{task_id}"] = acc 300 average_accuracy_over_tasks(acc) 301 new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item() 302 303 # write to the csv file 304 is_first = not os.path.exists(csv_path) 305 if not is_first: 306 with open(csv_path, "r", encoding="utf-8") as file: 307 lines = file.readlines() 308 del lines[0] 309 # write header 310 with open(csv_path, "w", encoding="utf-8") as file: 311 writer = csv.DictWriter(file, fieldnames=fieldnames) 312 writer.writeheader() 313 # write metrics 314 with open(csv_path, "a", encoding="utf-8") as file: 315 if not is_first: 316 file.writelines(lines) # write the previous lines 317 writer = csv.DictWriter(file, fieldnames=fieldnames) 318 writer.writerow(new_line) 319 320 def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None: 321 """Plot the test accuracy matrix from saved CSV file and save the plot to the designated directory. 322 323 **Args:** 324 - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric. 325 - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'. 
326 """ 327 data = pd.read_csv(csv_path) 328 processed_task_ids = [ 329 int(col.replace("test_on_task_", "")) 330 for col in data.columns 331 if col.startswith("test_on_task_") 332 ] 333 334 # Get all columns that start with "test_on_task_" 335 test_task_cols = [ 336 col for col in data.columns if col.startswith("test_on_task_") 337 ] 338 num_tasks = len(processed_task_ids) 339 num_rows = len(data) 340 341 # Build the accuracy matrix 342 acc_matrix = data[test_task_cols].values 343 344 fig, ax = plt.subplots( 345 figsize=(2 * num_tasks, 2 * num_rows) 346 ) # adaptive figure size 347 348 cax = ax.imshow( 349 acc_matrix, 350 interpolation="nearest", 351 cmap="Greens", 352 vmin=0, 353 vmax=1, 354 aspect="auto", 355 ) 356 357 colorbar = fig.colorbar(cax) 358 yticks = colorbar.ax.get_yticks() 359 colorbar.ax.set_yticks(yticks) 360 colorbar.ax.set_yticklabels( 361 [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks 362 ) 363 364 # Annotate each cell 365 for r in range(num_rows): 366 for c in range(r + 1): 367 ax.text( 368 c, 369 r, 370 f"{acc_matrix[r, c]:.3f}", 371 ha="center", 372 va="center", 373 color="black", 374 fontsize=10 + num_tasks, 375 ) 376 377 ax.set_xticks(range(num_tasks)) 378 ax.set_yticks(range(num_rows)) 379 ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks) 380 ax.set_yticklabels( 381 data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks 382 ) 383 384 ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks) 385 ax.set_ylabel("After training task t", fontsize=10 + num_tasks) 386 fig.tight_layout() 387 fig.savefig(plot_path) 388 plt.close(fig) 389 390 def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None: 391 """Plot the test average accuracy curve over different training tasks from saved CSV file and save the plot to the designated directory. 392 393 **Args:** 394 - **csv_path** (`str`): the path to the CSV file where the `utils.update_test_acc_to_csv()` saved the test accuracy metric. 395 - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'. 396 """ 397 data = pd.read_csv(csv_path) 398 after_training_tasks = data["after_training_task"].astype(int).tolist() 399 400 # plot the average accuracy curve over different training tasks 401 fig, ax = plt.subplots(figsize=(16, 9)) 402 ax.plot( 403 after_training_tasks, 404 data["average_accuracy"], 405 marker="o", 406 linewidth=2, 407 ) 408 ax.set_xlabel("After training task $t$", fontsize=16) 409 ax.set_ylabel("Average Accuracy (AA)", fontsize=16) 410 ax.grid(True) 411 xticks = after_training_tasks 412 yticks = [i * 0.05 for i in range(21)] 413 ax.set_xticks(xticks) 414 ax.set_yticks(yticks) 415 ax.set_xticklabels(xticks, fontsize=16) 416 ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16) 417 fig.savefig(plot_path) 418 plt.close(fig)
class CLAccuracy(MetricCallback):
Provides all actions related to the CL accuracy metric, which include:

- Defining, initializing, and recording the accuracy metric.
- Logging the training and validation accuracy metric to Lightning loggers in real time.
- Saving the test accuracy metric to files.
- Visualizing the test accuracy metric as plots.

The callback can produce the following outputs:

- A CSV file for the test accuracy (lower triangular) matrix and average accuracy. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
- A coloured plot of the test accuracy (lower triangular) matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.
- A curve plot of test average accuracy over different training tasks. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-average-test-performance-over-tasks) for details.

Please refer to [A Summary of Continual Learning Metrics](https://pengxiang-wang.com/posts/continual-learning-metrics) to learn about this metric.
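Concretely, assuming tasks are trained in order $1, \dots, t$, the average accuracy recorded after training task $t$ is the mean of the $t$-th row of the lower triangular matrix:

$$\mathrm{AA}_t = \frac{1}{t} \sum_{\tau=1}^{t} a_{t,\tau},$$

where $a_{t,\tau}$ is the test accuracy on task $\tau$ after training task $t$ (the value written to the `test_on_task_τ` column). This is what `update_test_acc_to_csv()` below computes with a `MeanMetric` over the entries of `acc_test`.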
    def __init__(
        self,
        save_dir: str,
        test_acc_csv_name: str = "acc.csv",
        test_acc_matrix_plot_name: str | None = None,
        test_ave_acc_plot_name: str | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
        - **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
        - **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
        - **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
        """
        super().__init__(save_dir=save_dir)

        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
        r"""The path to save test accuracy matrix and average accuracy CSV file."""
        if test_acc_matrix_plot_name:
            self.test_acc_matrix_plot_path: str = os.path.join(
                save_dir, test_acc_matrix_plot_name
            )
            r"""The path to save test accuracy matrix plot."""
        if test_ave_acc_plot_name:
            self.test_ave_acc_plot_path: str = os.path.join(
                save_dir, test_ave_acc_plot_name
            )
            r"""The path to save test average accuracy curve plot."""

        # training accumulated metrics
        self.acc_training_epoch: MeanMetricBatch
        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details."""

        # validation accumulated metrics
        self.acc_val: MeanMetricBatch
        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details."""

        # test accumulated metrics
        self.acc_test: dict[int, MeanMetricBatch]
        r"""Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details."""

        # task ID control
        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`."""
**Args:**

- **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
- **test_acc_csv_name** (`str`): file name to save test accuracy matrix and average accuracy as CSV file.
- **test_acc_matrix_plot_name** (`str` | `None`): file name to save accuracy matrix plot. If `None`, no file will be saved.
- **test_ave_acc_plot_name** (`str` | `None`): file name to save average accuracy as curve plot over different training tasks. If `None`, no file will be saved.
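As a usage illustration (not part of this module), the callback can be constructed with these arguments and passed to a Lightning `Trainer` like any other callback. The paths and file names below are hypothetical, and in practice the trainer and callback are typically created from the experiment configuration rather than by hand:

    # Hypothetical usage sketch; save_dir and file names are illustrative.
    from lightning import Trainer

    from clarena.metrics.cl_acc import CLAccuracy

    acc_metric = CLAccuracy(
        save_dir="./outputs/example_expr/results",
        test_acc_csv_name="acc.csv",
        test_acc_matrix_plot_name="acc_matrix.png",
        test_ave_acc_plot_name="ave_acc.png",
    )
    trainer = Trainer(max_epochs=10, callbacks=[acc_metric])
    # trainer.fit(...) and trainer.test(...) would then trigger the hooks documented below.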
**acc_training_epoch** (`MeanMetricBatch`): Classification accuracy of training epoch. Accumulated and calculated from the training batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-performance-of-training-epoch) for details.

**acc_val** (`MeanMetricBatch`): Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-validation-performace) for details.

**acc_test** (`dict[int, MeanMetricBatch]`): Test classification accuracy of the current model (`self.task_id`) on current and previous tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. It is the last row of the lower triangular matrix. See [here](https://pengxiang-wang.com/posts/continual-learning-metrics.html#sec-test-performance-of-previous-tasks) for details.

**task_id** (`int`): Task ID counter indicating which task is being processed. Self updated during the task loop. Valid from 1 to `cl_dataset.num_tasks`.
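These accumulators are instances of `MeanMetricBatch` from `clarena.utils.metrics`, which is not shown on this page. Conceptually it is a batch-size-weighted running mean over samples. A minimal sketch of such an accumulator, written against the `torchmetrics.Metric` API purely for illustration (it is not the actual implementation), could look like this:

    import torch
    from torchmetrics import Metric


    class WeightedMeanSketch(Metric):
        """Illustrative batch-size-weighted mean, mirroring how this callback uses
        MeanMetricBatch: update(value, batch_size) accumulates, compute() averages."""

        def __init__(self) -> None:
            super().__init__()
            self.add_state("weighted_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("num_samples", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, value: torch.Tensor, batch_size: int) -> None:
            # accumulate the batch value weighted by the number of samples in the batch
            self.weighted_sum += value * batch_size
            self.num_samples += batch_size

        def compute(self) -> torch.Tensor:
            # mean over all samples seen since the last reset()
            return self.weighted_sum / self.num_samples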
    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: CLAlgorithm) -> None:
        r"""Initialize training and validation metrics."""

        # set the current task_id from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize training metrics
        self.acc_training_epoch = MeanMetricBatch().to(device)

        # initialize validation metrics
        self.acc_val = MeanMetricBatch().to(device)
Initialize training and validation metrics.
    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from the training batch, and log the metrics of the training batch and the accumulated metrics of the epoch to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        """
        # get the batch size
        batch_size = len(batch)

        # get training metrics values of current training batch from the outputs of the `training_step()`
        acc_batch = outputs["acc"]

        # update accumulated training metrics to calculate training metrics of the epoch
        self.acc_training_epoch.update(acc_batch, batch_size)

        # log training metrics of current training batch to Lightning loggers
        pl_module.log(f"task_{self.task_id}/train/acc_batch", acc_batch, prog_bar=True)

        # log accumulated training metrics till this training batch to Lightning loggers
        pl_module.log(
            f"task_{self.task_id}/train/acc",
            self.acc_training_epoch.compute(),
            prog_bar=True,
        )
Record training metrics from the training batch, and log the metrics of the training batch and the accumulated metrics of the epoch to Lightning loggers.

**Args:**

- **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `CLAlgorithm`.
- **batch** (`Any`): the training data batch.
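The callback does not compute accuracy itself; it only assumes that the step outputs contain an `"acc"` entry. A hypothetical `CLAlgorithm` subclass whose `training_step()` satisfies this contract is sketched below; the forward call and loss are placeholders, not the actual `CLAlgorithm` API:

    # Hypothetical sketch: a training_step() that returns the "acc" entry this callback reads.
    class ExampleCLAlgorithm(CLAlgorithm):
        def training_step(self, batch: Any, batch_idx: int) -> dict[str, Any]:
            x, y = batch
            logits = self(x)  # placeholder forward pass; the real forward signature may differ
            loss = self.criterion(logits, y)  # placeholder loss attribute
            acc = (logits.argmax(dim=1) == y).float().mean()
            return {"loss": loss, "acc": acc}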
    @rank_zero_only
    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, especially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/train/acc",
            self.acc_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # reset the metrics of training epoch as there are more epochs to go, unlike the single epoch in validation and test
        self.acc_training_epoch.reset()
Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.
    @rank_zero_only
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Accumulate metrics from the validation batch. We don't need to log and monitor the metrics of validation batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the validation step, the returns of the `validation_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the validation data batch.
        """

        # get the batch size
        batch_size = len(batch)

        # get the metrics values of the batch from the outputs
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the validation metrics
        self.acc_val.update(acc_batch, batch_size)
Accumulate metrics from the validation batch. We don't need to log and monitor the metrics of validation batches.

**Args:**

- **outputs** (`dict[str, Any]`): the outputs of the validation step, the returns of the `validation_step()` method in the `CLAlgorithm`.
- **batch** (`Any`): the validation data batch.
    @rank_zero_only
    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Log validation metrics to plot learning curves."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, especially for plotting learning curves
        pl_module.log(
            f"task_{self.task_id}/learning_curve/val/acc",
            self.acc_val.compute(),
            on_epoch=True,
            prog_bar=True,
        )
Log validation metrics to plot learning curves.
    @rank_zero_only
    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Initialize the metrics for testing each seen task at the beginning of a task's testing."""

        # set the current task_id again (double checking) from the `CLAlgorithm` object
        self.task_id = pl_module.task_id

        # get the device to put the metrics on the same device
        device = pl_module.device

        # initialize test metrics for current and previous tasks
        self.acc_test = {
            task_id: MeanMetricBatch().to(device)
            for task_id in pl_module.processed_task_ids
        }
Initialize the metrics for testing each seen task at the beginning of a task's testing.
    @rank_zero_only
    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulate metrics from the test batch. We don't need to log and monitor the metrics of test batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, the returns of the `test_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the index of the test dataloader, mapped to the task ID of the seen task being tested. A default value of 0 is given, otherwise the `LightningModule` would raise a `RuntimeError`.
        """

        # get the batch size
        batch_size = len(batch)

        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)

        # get the metrics values of the batch from the outputs
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the metrics of the epoch
        self.acc_test[test_task_id].update(acc_batch, batch_size)
Accumulate metrics from the test batch. We don't need to log and monitor the metrics of test batches.

**Args:**

- **outputs** (`dict[str, Any]`): the outputs of the test step, the returns of the `test_step()` method in the `CLAlgorithm`.
- **batch** (`Any`): the test data batch.
- **dataloader_idx** (`int`): the index of the test dataloader, mapped to the task ID of the seen task being tested. A default value of 0 is given, otherwise the `LightningModule` would raise a `RuntimeError`.
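`get_test_task_id_from_dataloader_idx()` is provided by the `CLAlgorithm` and is not shown here. Assuming the test dataloaders are created in the same order as `processed_task_ids`, it presumably reduces to a simple lookup along these lines (an illustrative guess, not the actual implementation):

    # Illustrative guess at the mapping used above, assuming test dataloaders
    # are ordered the same way as processed_task_ids.
    def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        return self.processed_task_ids[dataloader_idx]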
    @rank_zero_only
    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: CLAlgorithm,
    ) -> None:
        r"""Save and plot test metrics at the end of test."""

        # save (update) the test metrics to CSV files
        self.update_test_acc_to_csv(
            after_training_task_id=self.task_id,
            csv_path=self.test_acc_csv_path,
        )

        # plot the test metrics
        if hasattr(self, "test_acc_matrix_plot_path"):
            self.plot_test_acc_matrix_from_csv(
                csv_path=self.test_acc_csv_path,
                plot_path=self.test_acc_matrix_plot_path,
            )
        if hasattr(self, "test_ave_acc_plot_path"):
            self.plot_test_ave_acc_curve_from_csv(
                csv_path=self.test_acc_csv_path,
                plot_path=self.test_ave_acc_plot_path,
            )
Save and plot test metrics at the end of test.
    def update_test_acc_to_csv(
        self,
        after_training_task_id: int,
        csv_path: str,
    ) -> None:
        r"""Append the test accuracy metrics of seen tasks as the last line of an existing CSV file. A new file will be created if it does not exist.

        **Args:**
        - **after_training_task_id** (`int`): the ID of the task whose training has just finished.
        - **csv_path** (`str`): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
        """
        processed_task_ids = list(self.acc_test.keys())
        fieldnames = ["after_training_task", "average_accuracy"] + [
            f"test_on_task_{task_id}" for task_id in processed_task_ids
        ]

        new_line = {
            "after_training_task": after_training_task_id
        }  # construct the first column

        # construct the columns and calculate the average accuracy over tasks at the same time
        average_accuracy_over_tasks = MeanMetric().to(
            device=next(iter(self.acc_test.values())).device
        )
        for task_id in processed_task_ids:
            acc = self.acc_test[task_id].compute().item()
            new_line[f"test_on_task_{task_id}"] = acc
            average_accuracy_over_tasks(acc)
        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()

        # write to the csv file
        is_first = not os.path.exists(csv_path)
        if not is_first:
            with open(csv_path, "r", encoding="utf-8") as file:
                lines = file.readlines()
                del lines[0]
        # write header
        with open(csv_path, "w", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
        # write metrics
        with open(csv_path, "a", encoding="utf-8") as file:
            if not is_first:
                file.writelines(lines)  # write the previous lines
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writerow(new_line)
Append the test accuracy metrics of seen tasks as the last line of an existing CSV file. A new file will be created if it does not exist.

**Args:**

- **after_training_task_id** (`int`): the ID of the task whose training has just finished.
- **csv_path** (`str`): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
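With this column layout, the CSV gains one row per finished task, and earlier rows carry fewer `test_on_task_*` columns than the final header (the header is rewritten with the full set of seen tasks each time). After three tasks, the file would look roughly like this; the accuracy values are made-up placeholders, shown only to illustrate the lower triangular structure:

    after_training_task,average_accuracy,test_on_task_1,test_on_task_2,test_on_task_3
    1,0.95,0.95
    2,0.91,0.89,0.93
    3,0.88,0.85,0.87,0.92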
    def plot_test_acc_matrix_from_csv(self, csv_path: str, plot_path: str) -> None:
        """Plot the test accuracy matrix from the saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where `update_test_acc_to_csv()` saved the test accuracy metric.
        - **plot_path** (`str`): the path to save the plot, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
        """
        data = pd.read_csv(csv_path)
        processed_task_ids = [
            int(col.replace("test_on_task_", ""))
            for col in data.columns
            if col.startswith("test_on_task_")
        ]

        # get all columns that start with "test_on_task_"
        test_task_cols = [
            col for col in data.columns if col.startswith("test_on_task_")
        ]
        num_tasks = len(processed_task_ids)
        num_rows = len(data)

        # build the accuracy matrix
        acc_matrix = data[test_task_cols].values

        fig, ax = plt.subplots(
            figsize=(2 * num_tasks, 2 * num_rows)
        )  # adaptive figure size

        cax = ax.imshow(
            acc_matrix,
            interpolation="nearest",
            cmap="Greens",
            vmin=0,
            vmax=1,
            aspect="auto",
        )

        colorbar = fig.colorbar(cax)
        yticks = colorbar.ax.get_yticks()
        colorbar.ax.set_yticks(yticks)
        colorbar.ax.set_yticklabels(
            [f"{tick:.2f}" for tick in yticks], fontsize=10 + num_tasks
        )

        # annotate each cell of the lower triangular part
        for r in range(num_rows):
            for c in range(r + 1):
                ax.text(
                    c,
                    r,
                    f"{acc_matrix[r, c]:.3f}",
                    ha="center",
                    va="center",
                    color="black",
                    fontsize=10 + num_tasks,
                )

        ax.set_xticks(range(num_tasks))
        ax.set_yticks(range(num_rows))
        ax.set_xticklabels(processed_task_ids, fontsize=10 + num_tasks)
        ax.set_yticklabels(
            data["after_training_task"].astype(int).tolist(), fontsize=10 + num_tasks
        )

        ax.set_xlabel("Testing on task τ", fontsize=10 + num_tasks)
        ax.set_ylabel("After training task t", fontsize=10 + num_tasks)
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)
Plot the test accuracy matrix from the saved CSV file and save the plot to the designated directory.

**Args:**

- **csv_path** (`str`): the path to the CSV file where `update_test_acc_to_csv()` saved the test accuracy metric.
- **plot_path** (`str`): the path to save the plot, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc_matrix.png'.
    def plot_test_ave_acc_curve_from_csv(self, csv_path: str, plot_path: str) -> None:
        """Plot the test average accuracy curve over different training tasks from the saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where `update_test_acc_to_csv()` saved the test accuracy metric.
        - **plot_path** (`str`): the path to save the plot, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
        """
        data = pd.read_csv(csv_path)
        after_training_tasks = data["after_training_task"].astype(int).tolist()

        # plot the average accuracy curve over different training tasks
        fig, ax = plt.subplots(figsize=(16, 9))
        ax.plot(
            after_training_tasks,
            data["average_accuracy"],
            marker="o",
            linewidth=2,
        )
        ax.set_xlabel("After training task $t$", fontsize=16)
        ax.set_ylabel("Average Accuracy (AA)", fontsize=16)
        ax.grid(True)
        xticks = after_training_tasks
        yticks = [i * 0.05 for i in range(21)]
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_xticklabels(xticks, fontsize=16)
        ax.set_yticklabels([f"{tick:.2f}" for tick in yticks], fontsize=16)
        fig.savefig(plot_path)
        plt.close(fig)
Plot the test average accuracy curve over different training tasks from the saved CSV file and save the plot to the designated directory.

**Args:**

- **csv_path** (`str`): the path to the CSV file where `update_test_acc_to_csv()` saved the test accuracy metric.
- **plot_path** (`str`): the path to save the plot, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/ave_acc.png'.
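Because both plotting helpers read everything they need back from the CSV file, the figures can be regenerated from a finished run without re-running any testing. A small sketch, assuming the callback can be constructed standalone and with illustrative paths:

    # Regenerate the figures from a previously saved acc.csv; all paths are illustrative.
    from clarena.metrics.cl_acc import CLAccuracy

    metric = CLAccuracy(
        save_dir="./outputs/example_expr/results",
        test_acc_matrix_plot_name="acc_matrix.png",
        test_ave_acc_plot_name="ave_acc.png",
    )
    metric.plot_test_acc_matrix_from_csv(
        csv_path="./outputs/example_expr/results/acc.csv",
        plot_path="./outputs/example_expr/results/acc_matrix.png",
    )
    metric.plot_test_ave_acc_curve_from_csv(
        csv_path="./outputs/example_expr/results/acc.csv",
        plot_path="./outputs/example_expr/results/ave_acc.png",
    )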