clarena.metrics.mtl_acc
The submodule in `metrics` for `MTLAccuracy`.
1r""" 2The submodule in `metrics` for `MTLAccuracy`. 3""" 4 5__all__ = ["MTLAccuracy"] 6 7import csv 8import logging 9import os 10from typing import Any 11 12import pandas as pd 13from lightning import Trainer 14from matplotlib import pyplot as plt 15from torchmetrics import MeanMetric 16 17from clarena.metrics import MetricCallback 18from clarena.mtl_algorithms import MTLAlgorithm 19from clarena.utils.metrics import MeanMetricBatch 20 21# always get logger for built-in logging in each module 22pylogger = logging.getLogger(__name__) 23 24 25class MTLAccuracy(MetricCallback): 26 r"""Provides all actions that are related to MTL accuracy metrics, which include: 27 28 - Defining, initializing and recording accuracy metric. 29 - Logging training and validation accuracy metric to Lightning loggers in real time. 30 - Saving test accuracy metric to files. 31 - Visualizing test accuracy metric as plots. 32 33 The callback is able to produce the following outputs: 34 35 - CSV files for test accuracy of all tasks and average accuracy. 36 - Bar charts for test accuracy of all tasks. 37 """ 38 39 def __init__( 40 self, 41 save_dir: str, 42 test_acc_csv_name: str = "acc.csv", 43 test_acc_plot_name: str | None = None, 44 ) -> None: 45 r""" 46 **Args:** 47 - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder. 48 - **test_acc_csv_name** (`str`): file name to save test accuracy of all tasks and average accuracy as CSV file. 49 - **test_acc_plot_name** (`str` | `None`): file name to save accuracy plot. If `None`, no file will be saved. 50 """ 51 super().__init__(save_dir=save_dir) 52 53 # paths 54 self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name) 55 r"""The path to save test accuracy of all tasks and average accuracy CSV file.""" 56 if test_acc_plot_name: 57 self.test_acc_plot_path: str = os.path.join(save_dir, test_acc_plot_name) 58 r"""The path to save test accuracy plot.""" 59 60 # training accumulated metrics 61 self.acc_training_epoch: MeanMetricBatch 62 r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. """ 63 64 # validation accumulated metrics 65 self.acc_val: dict[int, MeanMetricBatch] = {} 66 r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics.""" 67 68 # test accumulated metrics 69 self.acc_test: dict[int, MeanMetricBatch] = {} 70 r"""Test classification accuracy of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics.""" 71 72 def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None: 73 r"""Initialize training and validation metrics.""" 74 75 # initialize training metrics 76 self.acc_training_epoch = MeanMetricBatch() 77 78 # initialize validation metrics 79 self.acc_val = { 80 task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks 81 } 82 83 def on_train_batch_end( 84 self, 85 trainer: Trainer, 86 pl_module: MTLAlgorithm, 87 outputs: dict[str, Any], 88 batch: Any, 89 batch_idx: int, 90 ) -> None: 91 r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers. 92 93 **Args:** 94 - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `MTLAlgorithm`. 
95 - **batch** (`Any`): the training data batch. 96 """ 97 # get the batch size 98 batch_size = len(batch) 99 100 # get training metrics values of current training batch from the outputs of the `training_step()` 101 acc_batch = outputs["acc"] 102 103 # update accumulated training metrics to calculate training metrics of the epoch 104 self.acc_training_epoch.update(acc_batch, batch_size) 105 106 # log training metrics of current training batch to Lightning loggers 107 pl_module.log("train/acc_batch", acc_batch, prog_bar=True) 108 109 # log accumulated training metrics till this training batch to Lightning loggers 110 pl_module.log( 111 "task/train/acc", 112 self.acc_training_epoch.compute(), 113 prog_bar=True, 114 ) 115 116 def on_train_epoch_end( 117 self, 118 trainer: Trainer, 119 pl_module: MTLAlgorithm, 120 ) -> None: 121 r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.""" 122 123 # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test 124 self.acc_training_epoch.reset() 125 126 def on_validation_batch_end( 127 self, 128 trainer: Trainer, 129 pl_module: MTLAlgorithm, 130 outputs: dict[str, Any], 131 batch: Any, 132 batch_idx: int, 133 dataloader_idx: int = 0, 134 ) -> None: 135 r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches. 136 137 **Args:** 138 - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `MTLAlgorithm`. 139 - **batch** (`Any`): the validation data batch. 140 - **dataloader_idx** (`int`): the task ID of the validation dataloader. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 141 """ 142 143 # get the batch size 144 batch_size = len(batch) 145 146 val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx) 147 148 # get the metrics values of the batch from the outputs 149 acc_batch = outputs["acc"] 150 151 # update the accumulated metrics in order to calculate the validation metrics 152 self.acc_val[val_task_id].update(acc_batch, batch_size) 153 154 def on_test_start( 155 self, 156 trainer: Trainer, 157 pl_module: MTLAlgorithm, 158 ) -> None: 159 r"""Initialize the metrics for testing each seen task in the beginning of a task's testing.""" 160 161 # initialize test metrics for current and previous tasks 162 self.acc_test = { 163 task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks 164 } 165 166 def on_test_batch_end( 167 self, 168 trainer: Trainer, 169 pl_module: MTLAlgorithm, 170 outputs: dict[str, Any], 171 batch: Any, 172 batch_idx: int, 173 dataloader_idx: int = 0, 174 ) -> None: 175 r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches. 176 177 **Args:** 178 - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `MTLAlgorithm`. 179 - **batch** (`Any`): the validation data batch. 180 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 
181 """ 182 183 # get the batch size 184 batch_size = len(batch) 185 186 test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx) 187 188 # get the metrics values of the batch from the outputs 189 acc_batch = outputs["acc"] 190 191 # update the accumulated metrics in order to calculate the metrics of the epoch 192 self.acc_test[test_task_id].update(acc_batch, batch_size) 193 194 def on_test_epoch_end( 195 self, 196 trainer: Trainer, 197 pl_module: MTLAlgorithm, 198 ) -> None: 199 r"""Save and plot test metrics at the end of test.""" 200 201 # save (update) the test metrics to CSV files 202 self.save_test_acc_to_csv( 203 csv_path=self.test_acc_csv_path, 204 ) 205 206 # plot the test metrics 207 if hasattr(self, "test_acc_plot_path"): 208 self.plot_test_acc_from_csv( 209 csv_path=self.test_acc_csv_path, 210 plot_path=self.test_acc_plot_path, 211 ) 212 213 def save_test_acc_to_csv( 214 self, 215 csv_path: str, 216 ) -> None: 217 r"""Save the test accuracy metrics of all tasks in multi-task learning to an CSV file. 218 219 **Args:** 220 - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'. 221 """ 222 all_task_ids = list(self.acc_test.keys()) 223 fieldnames = ["average_accuracy"] + [ 224 f"test_on_task_{task_id}" for task_id in all_task_ids 225 ] 226 new_line = {} 227 228 # construct the columns and calculate the average accuracy over tasks at the same time 229 average_accuracy_over_tasks = MeanMetric().to( 230 device=next(iter(self.acc_test.values())).device 231 ) 232 for task_id in all_task_ids: 233 acc = self.acc_test[task_id].compute().item() 234 new_line[f"test_on_task_{task_id}"] = acc 235 average_accuracy_over_tasks(acc) 236 new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item() 237 238 # write 239 with open(csv_path, "w", encoding="utf-8") as file: 240 writer = csv.DictWriter(file, fieldnames=fieldnames) 241 writer.writeheader() 242 writer.writerow(new_line) 243 244 def plot_test_acc_from_csv(self, csv_path: str, plot_path: str) -> None: 245 """Plot the test accuracy bar chart of all tasks in multi-task learning from saved CSV file and save the plot to the designated directory. 246 247 **Args:** 248 - **csv_path** (`str`): the path to the csv file where the `utils.save_test_acc_csv()` saved the test accuracy metric. 249 - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/acc.png'. 250 """ 251 data = pd.read_csv(csv_path) 252 253 # extract all accuracy columns including average 254 all_columns = data.columns.tolist() 255 task_ids = list(range(len(all_columns))) # assign index-based positions 256 labels = [ 257 ( 258 col.replace("test_on_task_", "Task ") 259 if "test_on_task_" in col 260 else "Average" 261 ) 262 for col in all_columns 263 ] 264 accuracies = data.iloc[0][all_columns].values 265 266 # plot the accuracy bar chart over tasks 267 fig, ax = plt.subplots(figsize=(16, 9)) 268 ax.bar( 269 task_ids, 270 accuracies, 271 color="skyblue", 272 edgecolor="black", 273 ) 274 ax.set_xlabel("Task", fontsize=16) 275 ax.set_ylabel("Accuracy", fontsize=16) 276 ax.grid(True) 277 ax.set_xticks(task_ids) 278 ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14) 279 ax.set_yticks([i * 0.05 for i in range(21)]) 280 ax.set_yticklabels( 281 [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14 282 ) 283 fig.tight_layout() 284 fig.savefig(plot_path) 285 plt.close(fig)
```python
class MTLAccuracy(MetricCallback):
```
Provides all actions related to MTL accuracy metrics, including:
- Defining, initializing and recording the accuracy metric.
- Logging training and validation accuracy to Lightning loggers in real time.
- Saving test accuracy to files.
- Visualizing test accuracy as plots.
This callback produces the following outputs:
- CSV files with the test accuracy of all tasks and the average accuracy.
- Bar charts of the test accuracy of all tasks.
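For orientation, a minimal usage sketch follows. The directory layout is illustrative only, and how the callback is passed to the Lightning `Trainer` alongside the `MTLAlgorithm` and datamodule is described in comments rather than shown, since that setup is outside this module.

```python
from clarena.metrics.mtl_acc import MTLAccuracy

# configure where the metric files go; the paths below are illustrative
metric_cb = MTLAccuracy(
    save_dir="./outputs/expr_name/1970-01-01_00-00-00/results",
    test_acc_csv_name="acc.csv",
    test_acc_plot_name="acc.png",  # pass None to skip the bar chart
)

print(metric_cb.test_acc_csv_path)   # .../results/acc.csv
print(metric_cb.test_acc_plot_path)  # .../results/acc.png

# the callback is then handed to the Lightning Trainer together with the
# MTLAlgorithm and datamodule, e.g. Trainer(callbacks=[metric_cb], ...),
# so that its hooks fire during fit() and test()
```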
```python
    def __init__(
        self,
        save_dir: str,
        test_acc_csv_name: str = "acc.csv",
        test_acc_plot_name: str | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): the directory where data and figures of metrics will be saved. Better inside the output folder.
        - **test_acc_csv_name** (`str`): file name to save test accuracy of all tasks and average accuracy as a CSV file.
        - **test_acc_plot_name** (`str` | `None`): file name to save the accuracy plot. If `None`, no file will be saved.
        """
        super().__init__(save_dir=save_dir)

        # paths
        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
        r"""The path to save the CSV file of test accuracy of all tasks and average accuracy."""
        if test_acc_plot_name:
            self.test_acc_plot_path: str = os.path.join(save_dir, test_acc_plot_name)
            r"""The path to save the test accuracy plot."""

        # training accumulated metrics
        self.acc_training_epoch: MeanMetricBatch
        r"""Classification accuracy of the training epoch. Accumulated and calculated from the training batches."""

        # validation accumulated metrics
        self.acc_val: dict[int, MeanMetricBatch] = {}
        r"""Validation classification accuracy of the model after the training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics."""

        # test accumulated metrics
        self.acc_test: dict[int, MeanMetricBatch] = {}
        r"""Test classification accuracy of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics."""
```
Args:
- save_dir (`str`): the directory where data and figures of metrics will be saved, preferably inside the output folder.
- test_acc_csv_name (`str`): file name for saving the test accuracy of all tasks and the average accuracy as a CSV file.
- test_acc_plot_name (`str` | `None`): file name for saving the accuracy plot. If `None`, no plot file is saved.
- acc_training_epoch: classification accuracy of the training epoch, accumulated and calculated from the training batches.
- acc_val: validation classification accuracy of the model after the training epoch, accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics.
- acc_test: test classification accuracy of all tasks, accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics.
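These accumulators are instances of `MeanMetricBatch` from `clarena.utils.metrics`, whose implementation is not shown on this page. The sketch below is only an assumption about the batch-size-weighted running mean that the `update(value, batch_size)`, `compute()` and `reset()` calls in this callback rely on; it is not the library's code.

```python
import torch


class WeightedRunningMean:
    """Hypothetical stand-in for MeanMetricBatch: a mean weighted by batch size."""

    def __init__(self) -> None:
        self.weighted_sum = torch.tensor(0.0)  # sum of value * batch_size
        self.total = torch.tensor(0.0)         # sum of batch sizes

    def update(self, value: torch.Tensor | float, batch_size: int) -> None:
        # each batch contributes its metric value weighted by how many samples it held
        self.weighted_sum = self.weighted_sum + torch.as_tensor(value) * batch_size
        self.total = self.total + batch_size

    def compute(self) -> torch.Tensor:
        # running mean over all samples seen so far
        return self.weighted_sum / self.total

    def reset(self) -> None:
        self.weighted_sum = torch.tensor(0.0)
        self.total = torch.tensor(0.0)
```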
```python
    def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None:
        r"""Initialize training and validation metrics."""

        # initialize training metrics
        self.acc_training_epoch = MeanMetricBatch()

        # initialize validation metrics
        self.acc_val = {
            task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks
        }
```
Initialize training and validation metrics.
```python
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from the training batch and log the batch metrics and the accumulated epoch metrics to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, i.e. the return value of the `training_step()` method in the `MTLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        """
        # get the batch size
        batch_size = len(batch)

        # get the training metric values of the current training batch from the outputs of `training_step()`
        acc_batch = outputs["acc"]

        # update the accumulated training metrics to calculate the training metrics of the epoch
        self.acc_training_epoch.update(acc_batch, batch_size)

        # log the training metrics of the current training batch to Lightning loggers
        pl_module.log("train/acc_batch", acc_batch, prog_bar=True)

        # log the accumulated training metrics up to this training batch to Lightning loggers
        pl_module.log(
            "task/train/acc",
            self.acc_training_epoch.compute(),
            prog_bar=True,
        )
```
Record training metrics from the training batch and log both the batch metrics and the accumulated epoch metrics to Lightning loggers.
Args:
- outputs (`dict[str, Any]`): the outputs of the training step, i.e. the return value of the `training_step()` method in the `MTLAlgorithm`.
- batch (`Any`): the training data batch.
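All this hook expects from `outputs` is a dictionary carrying the batch accuracy under the key `"acc"`. A hypothetical sketch of a training-step return value illustrating that contract (the loss key and the way accuracy is computed here are assumptions, not the actual `MTLAlgorithm` code):

```python
import torch


def training_step_outputs(logits: torch.Tensor, targets: torch.Tensor) -> dict:
    """Hypothetical sketch of what an MTLAlgorithm training step might return."""
    loss = torch.nn.functional.cross_entropy(logits, targets)
    acc = (logits.argmax(dim=1) == targets).float().mean()
    # the callback only reads outputs["acc"]; anything else is up to the algorithm
    return {"loss": loss, "acc": acc}
```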
```python
    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
    ) -> None:
        r"""Reset the metric accumulation at the end of the training epoch; the accumulated epoch metric is already logged batch by batch in `on_train_batch_end()`."""

        # reset the training-epoch metrics, since more epochs follow,
        # unlike validation and test which run only once
        self.acc_training_epoch.reset()
```
Reset the metric accumulation at the end of the training epoch; the accumulated epoch-level accuracy used for learning curves is already logged batch by batch in `on_train_batch_end()`.
```python
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulate metrics from the validation batch. Metrics of individual validation batches are not logged or monitored.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the validation step, i.e. the return value of the `validation_step()` method in the `MTLAlgorithm`.
        - **batch** (`Any`): the validation data batch.
        - **dataloader_idx** (`int`): the index of the validation dataloader, mapped to a task ID. A default value of 0 is required, otherwise the `LightningModule` will raise a `RuntimeError`.
        """

        # get the batch size
        batch_size = len(batch)

        val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx)

        # get the metric values of the batch from the outputs
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the validation metrics
        self.acc_val[val_task_id].update(acc_batch, batch_size)
```
Accumulate metrics from the validation batch. Metrics of individual validation batches are not logged or monitored.
Args:
- outputs (`dict[str, Any]`): the outputs of the validation step, i.e. the return value of the `validation_step()` method in the `MTLAlgorithm`.
- batch (`Any`): the validation data batch.
- dataloader_idx (`int`): the index of the validation dataloader, mapped to a task ID. A default value of 0 is required, otherwise the `LightningModule` will raise a `RuntimeError`.
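When the datamodule returns one validation dataloader per task, Lightning passes the position of that dataloader as `dataloader_idx`; `get_val_task_id_from_dataloader_idx()` on the algorithm maps it back to a task ID. Its real implementation is not shown on this page, so the snippet below is only a plausible sketch of such a lookup, with hypothetical task IDs.

```python
# a minimal sketch, assuming the datamodule keeps its tasks in a fixed order
train_tasks = [1, 2, 3]  # hypothetical task IDs, in dataloader order


def get_val_task_id_from_dataloader_idx(dataloader_idx: int) -> int:
    # dataloader_idx is positional; the task list fixes which task it refers to
    return train_tasks[dataloader_idx]


assert get_val_task_id_from_dataloader_idx(0) == 1
assert get_val_task_id_from_dataloader_idx(2) == 3
```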
```python
    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
    ) -> None:
        r"""Initialize the test metrics for each evaluated task at the beginning of testing."""

        # initialize test metrics for all evaluated tasks
        self.acc_test = {
            task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks
        }
```
Initialize the test metrics for each evaluated task at the beginning of testing.
```python
    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulate metrics from the test batch. Metrics of individual test batches are not logged or monitored.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, i.e. the return value of the `test_step()` method in the `MTLAlgorithm`.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the index of the test dataloader, mapped to the ID of the task being tested. A default value of 0 is required, otherwise the `LightningModule` will raise a `RuntimeError`.
        """

        # get the batch size
        batch_size = len(batch)

        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)

        # get the metric values of the batch from the outputs
        acc_batch = outputs["acc"]

        # update the accumulated metrics in order to calculate the test metrics
        self.acc_test[test_task_id].update(acc_batch, batch_size)
```
Accumulate metrics from the test batch. Metrics of individual test batches are not logged or monitored.
Args:
- outputs (`dict[str, Any]`): the outputs of the test step, i.e. the return value of the `test_step()` method in the `MTLAlgorithm`.
- batch (`Any`): the test data batch.
- dataloader_idx (`int`): the index of the test dataloader, mapped to the ID of the task being tested. A default value of 0 is required, otherwise the `LightningModule` will raise a `RuntimeError`.
```python
    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
    ) -> None:
        r"""Save and plot the test metrics at the end of testing."""

        # save (update) the test metrics to the CSV file
        self.save_test_acc_to_csv(
            csv_path=self.test_acc_csv_path,
        )

        # plot the test metrics
        if hasattr(self, "test_acc_plot_path"):
            self.plot_test_acc_from_csv(
                csv_path=self.test_acc_csv_path,
                plot_path=self.test_acc_plot_path,
            )
```
Save and plot the test metrics at the end of testing.
```python
    def save_test_acc_to_csv(
        self,
        csv_path: str,
    ) -> None:
        r"""Save the test accuracy metrics of all tasks in multi-task learning to a CSV file.

        **Args:**
        - **csv_path** (`str`): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
        """
        all_task_ids = list(self.acc_test.keys())
        fieldnames = ["average_accuracy"] + [
            f"test_on_task_{task_id}" for task_id in all_task_ids
        ]
        new_line = {}

        # construct the columns and calculate the average accuracy over tasks at the same time
        average_accuracy_over_tasks = MeanMetric().to(
            device=next(iter(self.acc_test.values())).device
        )
        for task_id in all_task_ids:
            acc = self.acc_test[task_id].compute().item()
            new_line[f"test_on_task_{task_id}"] = acc
            average_accuracy_over_tasks(acc)
        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()

        # write the header and the single result row
        with open(csv_path, "w", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerow(new_line)
```
Save the test accuracy metrics of all tasks in multi-task learning to a CSV file.
Args:
- csv_path (`str`): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
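The resulting file contains a single data row: one `test_on_task_{ID}` column per evaluated task plus the `average_accuracy` column. A small self-contained check of that layout, with purely illustrative task IDs and accuracy values:

```python
import csv
import io

import pandas as pd

# mimic the row that save_test_acc_to_csv() writes; task IDs and values are illustrative
fieldnames = ["average_accuracy", "test_on_task_1", "test_on_task_2", "test_on_task_3"]
row = {"test_on_task_1": 0.91, "test_on_task_2": 0.88, "test_on_task_3": 0.95}
row["average_accuracy"] = sum(row.values()) / len(row)

buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=fieldnames)
writer.writeheader()
writer.writerow(row)

buffer.seek(0)
data = pd.read_csv(buffer)
print(data.columns.tolist())            # ['average_accuracy', 'test_on_task_1', ...]
print(data.iloc[0]["average_accuracy"])  # mean of the per-task accuracies in the row
```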
```python
    def plot_test_acc_from_csv(self, csv_path: str, plot_path: str) -> None:
        """Plot the test accuracy bar chart of all tasks in multi-task learning from the saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the CSV file where `save_test_acc_to_csv()` saved the test accuracy metric.
        - **plot_path** (`str`): the path to save the plot to. Better the same as the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc.png'.
        """
        data = pd.read_csv(csv_path)

        # extract all accuracy columns, including the average
        all_columns = data.columns.tolist()
        task_ids = list(range(len(all_columns)))  # assign index-based positions
        labels = [
            (
                col.replace("test_on_task_", "Task ")
                if "test_on_task_" in col
                else "Average"
            )
            for col in all_columns
        ]
        accuracies = data.iloc[0][all_columns].values

        # plot the accuracy bar chart over tasks
        fig, ax = plt.subplots(figsize=(16, 9))
        ax.bar(
            task_ids,
            accuracies,
            color="skyblue",
            edgecolor="black",
        )
        ax.set_xlabel("Task", fontsize=16)
        ax.set_ylabel("Accuracy", fontsize=16)
        ax.grid(True)
        ax.set_xticks(task_ids)
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14)
        ax.set_yticks([i * 0.05 for i in range(21)])
        ax.set_yticklabels(
            [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14
        )
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)
```
Plot the test accuracy bar chart of all tasks in multi-task learning from the saved CSV file and save the plot to the designated directory.
Args:
- csv_path (`str`): the path to the CSV file where `save_test_acc_to_csv()` saved the test accuracy metric.
- plot_path (`str`): the path to save the plot to. Better the same as the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc.png'.
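Because the plot is driven entirely by the saved CSV, it can also be regenerated after the fact without rerunning the test loop. A minimal sketch, assuming the results directory already contains an `acc.csv` and that constructing the callback outside a training run is permitted; the directory names are illustrative.

```python
from clarena.metrics.mtl_acc import MTLAccuracy

# rebuild the bar chart from an existing acc.csv; directory names are hypothetical
results_dir = "./outputs/expr_name/1970-01-01_00-00-00/results"
metric_cb = MTLAccuracy(save_dir=results_dir, test_acc_plot_name="acc.png")

metric_cb.plot_test_acc_from_csv(
    csv_path=metric_cb.test_acc_csv_path,    # results_dir/acc.csv
    plot_path=metric_cb.test_acc_plot_path,  # results_dir/acc.png
)
```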