clarena.metrics.mtl_loss
The submodule in `metrics` for `MTLLoss`.
r"""
The submodule in `metrics` for `MTLLoss`.
"""

__all__ = ["MTLLoss"]

import csv
import logging
import os
from typing import Any

import pandas as pd
from lightning import Trainer
from matplotlib import pyplot as plt
from torchmetrics import MeanMetric

from clarena.metrics import MetricCallback
from clarena.mtl_algorithms import MTLAlgorithm
from clarena.utils.metrics import MeanMetricBatch

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class MTLLoss(MetricCallback):
    r"""Provides all actions that are related to MTL loss metrics, which include:

    - Defining, initializing and recording loss metrics.
    - Logging training and validation loss metrics to Lightning loggers in real time.
    - Saving test loss metrics to files.
    - Visualizing test loss metrics as plots.

    The callback is able to produce the following outputs:

    - CSV files for test classification loss of all tasks and average classification loss.
    - Bar charts for test classification loss of all tasks.
    """

    def __init__(
        self,
        save_dir: str,
        test_loss_cls_csv_name: str = "loss_cls.csv",
        test_loss_cls_plot_name: str | None = None,
    ) -> None:
        r"""
        **Args:**
        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
        - **test_loss_cls_csv_name** (`str`): file name to save classification loss of all tasks and average classification loss as CSV file.
        - **test_loss_cls_plot_name** (`str` | `None`): file name to save classification loss plot. If `None`, no file will be saved.
        """
        super().__init__(save_dir=save_dir)

        # paths
        self.test_loss_cls_csv_path: str = os.path.join(
            save_dir, test_loss_cls_csv_name
        )
        r"""The path to save test classification loss of all tasks and average classification loss CSV file."""
        # the plot path attribute only exists when a plot file name was given;
        # on_test_epoch_end() checks for its existence via hasattr()
        if test_loss_cls_plot_name:
            self.test_loss_cls_plot_path: str = os.path.join(
                save_dir, test_loss_cls_plot_name
            )
            r"""The path to save test classification loss plot."""

        # training accumulated metrics
        self.loss_cls_training_epoch: MeanMetricBatch
        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. """

        # validation accumulated metrics
        self.loss_cls_val: dict[int, MeanMetricBatch] = {}
        r"""Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics. """

        # test accumulated metrics
        self.loss_cls_test: dict[int, MeanMetricBatch] = {}
        r"""Test classification loss of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. """

    def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None:
        r"""Initialize training and validation metrics."""

        # initialize training metrics
        self.loss_cls_training_epoch = MeanMetricBatch()

        # initialize validation metrics, one accumulator per training task
        self.loss_cls_val = {
            task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks
        }

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
    ) -> None:
        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `MTLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        """
        # get the batch size
        # NOTE(review): if batch is an (inputs, targets) tuple, len(batch) is the
        # tuple arity rather than the sample count — confirm against how
        # MeanMetricBatch weights its updates elsewhere in the project
        batch_size = len(batch)

        # get training metrics values of current training batch from the outputs of the `training_step()`
        loss_cls_batch = outputs["loss_cls"]

        # update accumulated training metrics to calculate training metrics of the epoch
        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)

        # log training metrics of current training batch to Lightning loggers
        pl_module.log("train/loss_cls_batch", loss_cls_batch, prog_bar=True)

        # log accumulated training metrics till this training batch to Lightning loggers
        pl_module.log(
            "task/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            prog_bar=True,
        )

    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
    ) -> None:
        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""

        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
        pl_module.log(
            "learning_curve/train/loss_cls",
            self.loss_cls_training_epoch.compute(),
            on_epoch=True,
            prog_bar=True,
        )

        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
        self.loss_cls_training_epoch.reset()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `MTLAlgorithm`.
        - **batch** (`Any`): the validation data batch.
        - **dataloader_idx** (`int`): the task ID of the validation dataloader. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
        """
        # get the batch size
        batch_size = len(batch)

        # map dataloader index to task id
        val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx)

        # get the metrics values of the batch from the outputs
        loss_cls_batch = outputs["loss_cls"]

        # update the accumulated metrics in order to calculate the validation metrics
        self.loss_cls_val[val_task_id].update(loss_cls_batch, batch_size)

    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
    ) -> None:
        r"""Log validation metrics to plot learning curves."""

        # compute average validation loss over tasks for logging learning curves
        average_val_loss = MeanMetric().to(
            device=next(iter(self.loss_cls_val.values())).device
        )
        for metric in self.loss_cls_val.values():
            average_val_loss(metric.compute())

        pl_module.log(
            "learning_curve/val/loss_cls",
            average_val_loss.compute(),
            on_epoch=True,
            prog_bar=True,
        )

    def on_test_start(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
    ) -> None:
        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""

        # initialize test metrics, one accumulator per evaluation task
        self.loss_cls_test = {
            task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks
        }

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
        outputs: dict[str, Any],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `MTLAlgorithm`.
        - **batch** (`Any`): the test data batch.
        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
        """

        # get the batch size
        batch_size = len(batch)

        # map dataloader index to task id
        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)

        # get the metrics values of the batch from the outputs
        loss_cls_batch = outputs["loss_cls"]

        # update the accumulated metrics in order to calculate the metrics of the epoch
        self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size)

    def on_test_epoch_end(
        self,
        trainer: Trainer,
        pl_module: MTLAlgorithm,
    ) -> None:
        r"""Save and plot test metrics at the end of test."""

        # save (update) the test metrics to CSV files
        self.save_test_loss_cls_to_csv(
            csv_path=self.test_loss_cls_csv_path,
        )

        # plot the test metrics only when a plot path was configured in __init__()
        if hasattr(self, "test_loss_cls_plot_path"):
            self.plot_test_loss_cls_from_csv(
                csv_path=self.test_loss_cls_csv_path,
                plot_path=self.test_loss_cls_plot_path,
            )

    def save_test_loss_cls_to_csv(
        self,
        csv_path: str,
    ) -> None:
        r"""Save the test classification loss metrics of all tasks in multi-task learning to a CSV file.

        **Args:**
        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
        """
        all_task_ids = list(self.loss_cls_test.keys())
        fieldnames = ["average_classification_loss"] + [
            f"test_on_task_{task_id}" for task_id in all_task_ids
        ]
        new_line = {}

        # construct the columns and calculate the average loss over tasks at the same time
        average_loss_over_tasks = MeanMetric().to(
            device=next(iter(self.loss_cls_test.values())).device
        )
        for task_id in all_task_ids:
            loss = self.loss_cls_test[task_id].compute().item()
            new_line[f"test_on_task_{task_id}"] = loss
            average_loss_over_tasks(loss)
        new_line["average_classification_loss"] = (
            average_loss_over_tasks.compute().item()
        )

        # write; newline="" is required by the csv module to avoid blank lines
        # being inserted between rows on Windows
        with open(csv_path, "w", encoding="utf-8", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerow(new_line)

    def plot_test_loss_cls_from_csv(self, csv_path: str, plot_path: str) -> None:
        r"""Plot the test classification loss bar chart of all tasks in multi-task learning from saved CSV file and save the plot to the designated directory.

        **Args:**
        - **csv_path** (`str`): the path to the csv file where `save_test_loss_cls_to_csv()` saved the test classification loss metric.
        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls.png'.
        """
        data = pd.read_csv(csv_path)

        # extract all loss columns including average
        all_columns = data.columns.tolist()
        task_ids = list(range(len(all_columns)))  # assign index-based positions
        labels = [
            (
                col.replace("test_on_task_", "Task ")
                if "test_on_task_" in col
                else "Average"
            )
            for col in all_columns
        ]
        loss_cls = data.iloc[0][all_columns].values

        # plot the classification loss bar chart over tasks
        fig, ax = plt.subplots(figsize=(16, 9))
        ax.bar(
            task_ids,
            loss_cls,
            color="skyblue",
            edgecolor="black",
        )
        ax.set_xlabel("Task", fontsize=16)
        ax.set_ylabel("Classification Loss", fontsize=16)
        ax.grid(True)
        ax.set_xticks(task_ids)
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14)
        # NOTE(review): fixed ticks at 0.00–1.00 in 0.05 steps assume the loss
        # never exceeds 1.0 — bars above 1.0 would render without tick labels
        ax.set_yticks([i * 0.05 for i in range(21)])
        ax.set_yticklabels(
            [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14
        )
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)
26class MTLLoss(MetricCallback): 27 r"""Provides all actions that are related to MTL loss metrics, which include: 28 29 - Defining, initializing and recording loss metrics. 30 - Logging training and validation loss metrics to Lightning loggers in real time. 31 - Saving test loss metrics to files. 32 - Visualizing test loss metrics as plots. 33 34 The callback is able to produce the following outputs: 35 36 - CSV files for test classification loss of all tasks and average classification loss. 37 - Bar charts for test classification loss of all tasks. 38 """ 39 40 def __init__( 41 self, 42 save_dir: str, 43 test_loss_cls_csv_name: str = "loss_cls.csv", 44 test_loss_cls_plot_name: str | None = None, 45 ) -> None: 46 r""" 47 **Args:** 48 - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder. 49 - **test_loss_cls_csv_name**(`str`): file name to save classification loss of all tasks and average classification loss as CSV file. 50 - **test_loss_cls_plot_name** (`str` | `None`): file name to save classification loss plot. If `None`, no file will be saved. 51 """ 52 super().__init__(save_dir=save_dir) 53 54 # paths 55 self.test_loss_cls_csv_path: str = os.path.join( 56 save_dir, test_loss_cls_csv_name 57 ) 58 r"""The path to save test classification loss of all tasks and average classification loss CSV file.""" 59 if test_loss_cls_plot_name: 60 self.test_loss_cls_plot_path: str = os.path.join( 61 save_dir, test_loss_cls_plot_name 62 ) 63 r"""The path to save test classification loss plot.""" 64 65 # training accumulated metrics 66 self.loss_cls_training_epoch: MeanMetricBatch 67 r"""Classification loss of training epoch. Accumulated and calculated from the training batches. """ 68 69 # validation accumulated metrics 70 self.loss_cls_val: dict[int, MeanMetricBatch] = {} 71 r"""Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. 
Keys are task IDs and values are the corresponding metrics. """ 72 73 # test accumulated metrics 74 self.loss_cls_test: dict[int, MeanMetricBatch] = {} 75 r"""Test classification loss of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. """ 76 77 def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None: 78 r"""Initialize training and validation metrics.""" 79 80 # initialize training metrics 81 self.loss_cls_training_epoch = MeanMetricBatch() 82 83 # initialize validation metrics 84 self.loss_cls_val = { 85 task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks 86 } 87 88 def on_train_batch_end( 89 self, 90 trainer: Trainer, 91 pl_module: MTLAlgorithm, 92 outputs: dict[str, Any], 93 batch: Any, 94 batch_idx: int, 95 ) -> None: 96 r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers. 97 98 **Args:** 99 - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `MTLAlgorithm`. 100 - **batch** (`Any`): the training data batch. 
101 """ 102 # get the batch size 103 batch_size = len(batch) 104 105 # get training metrics values of current training batch from the outputs of the `training_step()` 106 loss_cls_batch = outputs["loss_cls"] 107 108 # update accumulated training metrics to calculate training metrics of the epoch 109 self.loss_cls_training_epoch.update(loss_cls_batch, batch_size) 110 111 # log training metrics of current training batch to Lightning loggers 112 pl_module.log("train/loss_cls_batch", loss_cls_batch, prog_bar=True) 113 114 # log accumulated training metrics till this training batch to Lightning loggers 115 pl_module.log( 116 "task/train/loss_cls", 117 self.loss_cls_training_epoch.compute(), 118 prog_bar=True, 119 ) 120 121 def on_train_epoch_end( 122 self, 123 trainer: Trainer, 124 pl_module: MTLAlgorithm, 125 ) -> None: 126 r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.""" 127 128 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 129 pl_module.log( 130 "learning_curve/train/loss_cls", 131 self.loss_cls_training_epoch.compute(), 132 on_epoch=True, 133 prog_bar=True, 134 ) 135 136 # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test 137 self.loss_cls_training_epoch.reset() 138 139 def on_validation_batch_end( 140 self, 141 trainer: Trainer, 142 pl_module: MTLAlgorithm, 143 outputs: dict[str, Any], 144 batch: Any, 145 batch_idx: int, 146 dataloader_idx: int = 0, 147 ) -> None: 148 r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches. 149 150 **Args:** 151 - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `MTLAlgorithm`. 152 - **batch** (`Any`): the validation data batch. 
153 - **dataloader_idx** (`int`): the task ID of the validation dataloader. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 154 """ 155 # get the batch size 156 batch_size = len(batch) 157 158 # map dataloader index to task id 159 val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx) 160 161 # get the metrics values of the batch from the outputs 162 loss_cls_batch = outputs["loss_cls"] 163 164 # update the accumulated metrics in order to calculate the validation metrics 165 self.loss_cls_val[val_task_id].update(loss_cls_batch, batch_size) 166 167 def on_validation_epoch_end( 168 self, 169 trainer: Trainer, 170 pl_module: MTLAlgorithm, 171 ) -> None: 172 r"""Log validation metrics to plot learning curves.""" 173 174 # compute average validation loss over tasks for logging learning curves 175 average_val_loss = MeanMetric().to( 176 device=next(iter(self.loss_cls_val.values())).device 177 ) 178 for metric in self.loss_cls_val.values(): 179 average_val_loss(metric.compute()) 180 181 pl_module.log( 182 "learning_curve/val/loss_cls", 183 average_val_loss.compute(), 184 on_epoch=True, 185 prog_bar=True, 186 ) 187 188 def on_test_start( 189 self, 190 trainer: Trainer, 191 pl_module: MTLAlgorithm, 192 ) -> None: 193 r"""Initialize the metrics for testing each seen task in the beginning of a task's testing.""" 194 195 # initialize test metrics for current and previous tasks 196 self.loss_cls_test = { 197 task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks 198 } 199 200 def on_test_batch_end( 201 self, 202 trainer: Trainer, 203 pl_module: MTLAlgorithm, 204 outputs: dict[str, Any], 205 batch: Any, 206 batch_idx: int, 207 dataloader_idx: int = 0, 208 ) -> None: 209 r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches. 
210 211 **Args:** 212 - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `MTLAlgorithm`. 213 - **batch** (`Any`): the validation data batch. 214 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 215 """ 216 217 # get the batch size 218 batch_size = len(batch) 219 220 test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx) 221 222 # get the metrics values of the batch from the outputs 223 loss_cls_batch = outputs["loss_cls"] 224 225 # update the accumulated metrics in order to calculate the metrics of the epoch 226 self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size) 227 228 def on_test_epoch_end( 229 self, 230 trainer: Trainer, 231 pl_module: MTLAlgorithm, 232 ) -> None: 233 r"""Save and plot test metrics at the end of test.""" 234 235 # save (update) the test metrics to CSV files 236 self.save_test_loss_cls_to_csv( 237 csv_path=self.test_loss_cls_csv_path, 238 ) 239 240 # plot the test metrics 241 if hasattr(self, "test_loss_cls_plot_path"): 242 self.plot_test_loss_cls_from_csv( 243 csv_path=self.test_loss_cls_csv_path, 244 plot_path=self.test_loss_cls_plot_path, 245 ) 246 247 def save_test_loss_cls_to_csv( 248 self, 249 csv_path: str, 250 ) -> None: 251 r"""Save the test classification loss metrics of all tasks in multi-task learning to an CSV file. 252 253 **Args:** 254 - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'. 
255 """ 256 all_task_ids = list(self.loss_cls_test.keys()) 257 fieldnames = ["average_classification_loss"] + [ 258 f"test_on_task_{task_id}" for task_id in all_task_ids 259 ] 260 new_line = {} 261 262 # construct the columns and calculate the average loss over tasks at the same time 263 average_loss_over_tasks = MeanMetric().to( 264 device=next(iter(self.loss_cls_test.values())).device 265 ) 266 for task_id in all_task_ids: 267 loss = self.loss_cls_test[task_id].compute().item() 268 new_line[f"test_on_task_{task_id}"] = loss 269 average_loss_over_tasks(loss) 270 new_line["average_classification_loss"] = ( 271 average_loss_over_tasks.compute().item() 272 ) 273 274 # write 275 with open(csv_path, "w", encoding="utf-8") as file: 276 writer = csv.DictWriter(file, fieldnames=fieldnames) 277 writer.writeheader() 278 writer.writerow(new_line) 279 280 def plot_test_loss_cls_from_csv(self, csv_path: str, plot_path: str) -> None: 281 """Plot the test classification loss bar chart of all tasks in multi-task learning from saved CSV file and save the plot to the designated directory. 282 283 **Args:** 284 - **csv_path** (`str`): the path to the csv file where the `utils.save_test_acc_csv()` saved the test classification loss metric. 285 - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls.png'. 
286 """ 287 data = pd.read_csv(csv_path) 288 289 # extract all accuracy columns including average 290 all_columns = data.columns.tolist() 291 task_ids = list(range(len(all_columns))) # assign index-based positions 292 labels = [ 293 ( 294 col.replace("test_on_task_", "Task ") 295 if "test_on_task_" in col 296 else "Average" 297 ) 298 for col in all_columns 299 ] 300 loss_cls = data.iloc[0][all_columns].values 301 302 # plot the classification loss bar chart over tasks 303 fig, ax = plt.subplots(figsize=(16, 9)) 304 ax.bar( 305 task_ids, 306 loss_cls, 307 color="skyblue", 308 edgecolor="black", 309 ) 310 ax.set_xlabel("Task", fontsize=16) 311 ax.set_ylabel("Classification Loss", fontsize=16) 312 ax.grid(True) 313 ax.set_xticks(task_ids) 314 ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14) 315 ax.set_yticks([i * 0.05 for i in range(21)]) 316 ax.set_yticklabels( 317 [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14 318 ) 319 fig.tight_layout() 320 fig.savefig(plot_path) 321 plt.close(fig)
Provides all actions that are related to MTL loss metrics, which include:
- Defining, initializing and recording loss metrics.
- Logging training and validation loss metrics to Lightning loggers in real time.
- Saving test loss metrics to files.
- Visualizing test loss metrics as plots.
The callback is able to produce the following outputs:
- CSV files for test classification loss of all tasks and average classification loss.
- Bar charts for test classification loss of all tasks.
40 def __init__( 41 self, 42 save_dir: str, 43 test_loss_cls_csv_name: str = "loss_cls.csv", 44 test_loss_cls_plot_name: str | None = None, 45 ) -> None: 46 r""" 47 **Args:** 48 - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder. 49 - **test_loss_cls_csv_name**(`str`): file name to save classification loss of all tasks and average classification loss as CSV file. 50 - **test_loss_cls_plot_name** (`str` | `None`): file name to save classification loss plot. If `None`, no file will be saved. 51 """ 52 super().__init__(save_dir=save_dir) 53 54 # paths 55 self.test_loss_cls_csv_path: str = os.path.join( 56 save_dir, test_loss_cls_csv_name 57 ) 58 r"""The path to save test classification loss of all tasks and average classification loss CSV file.""" 59 if test_loss_cls_plot_name: 60 self.test_loss_cls_plot_path: str = os.path.join( 61 save_dir, test_loss_cls_plot_name 62 ) 63 r"""The path to save test classification loss plot.""" 64 65 # training accumulated metrics 66 self.loss_cls_training_epoch: MeanMetricBatch 67 r"""Classification loss of training epoch. Accumulated and calculated from the training batches. """ 68 69 # validation accumulated metrics 70 self.loss_cls_val: dict[int, MeanMetricBatch] = {} 71 r"""Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics. """ 72 73 # test accumulated metrics 74 self.loss_cls_test: dict[int, MeanMetricBatch] = {} 75 r"""Test classification loss of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. """
Args:
- save_dir (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
- test_loss_cls_csv_name (`str`): file name to save classification loss of all tasks and average classification loss as CSV file.
- test_loss_cls_plot_name (`str` | `None`): file name to save classification loss plot. If `None`, no file will be saved.
The path to save test classification loss of all tasks and average classification loss CSV file.
Classification loss of training epoch. Accumulated and calculated from the training batches.
Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics.
Test classification loss of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics.
77 def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None: 78 r"""Initialize training and validation metrics.""" 79 80 # initialize training metrics 81 self.loss_cls_training_epoch = MeanMetricBatch() 82 83 # initialize validation metrics 84 self.loss_cls_val = { 85 task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks 86 }
Initialize training and validation metrics.
88 def on_train_batch_end( 89 self, 90 trainer: Trainer, 91 pl_module: MTLAlgorithm, 92 outputs: dict[str, Any], 93 batch: Any, 94 batch_idx: int, 95 ) -> None: 96 r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers. 97 98 **Args:** 99 - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `MTLAlgorithm`. 100 - **batch** (`Any`): the training data batch. 101 """ 102 # get the batch size 103 batch_size = len(batch) 104 105 # get training metrics values of current training batch from the outputs of the `training_step()` 106 loss_cls_batch = outputs["loss_cls"] 107 108 # update accumulated training metrics to calculate training metrics of the epoch 109 self.loss_cls_training_epoch.update(loss_cls_batch, batch_size) 110 111 # log training metrics of current training batch to Lightning loggers 112 pl_module.log("train/loss_cls_batch", loss_cls_batch, prog_bar=True) 113 114 # log accumulated training metrics till this training batch to Lightning loggers 115 pl_module.log( 116 "task/train/loss_cls", 117 self.loss_cls_training_epoch.compute(), 118 prog_bar=True, 119 )
Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
Args:
- outputs (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `MTLAlgorithm`.
- batch (`Any`): the training data batch.
121 def on_train_epoch_end( 122 self, 123 trainer: Trainer, 124 pl_module: MTLAlgorithm, 125 ) -> None: 126 r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.""" 127 128 # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves 129 pl_module.log( 130 "learning_curve/train/loss_cls", 131 self.loss_cls_training_epoch.compute(), 132 on_epoch=True, 133 prog_bar=True, 134 ) 135 136 # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test 137 self.loss_cls_training_epoch.reset()
Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.
139 def on_validation_batch_end( 140 self, 141 trainer: Trainer, 142 pl_module: MTLAlgorithm, 143 outputs: dict[str, Any], 144 batch: Any, 145 batch_idx: int, 146 dataloader_idx: int = 0, 147 ) -> None: 148 r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches. 149 150 **Args:** 151 - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `MTLAlgorithm`. 152 - **batch** (`Any`): the validation data batch. 153 - **dataloader_idx** (`int`): the task ID of the validation dataloader. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 154 """ 155 # get the batch size 156 batch_size = len(batch) 157 158 # map dataloader index to task id 159 val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx) 160 161 # get the metrics values of the batch from the outputs 162 loss_cls_batch = outputs["loss_cls"] 163 164 # update the accumulated metrics in order to calculate the validation metrics 165 self.loss_cls_val[val_task_id].update(loss_cls_batch, batch_size)
Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
Args:
- outputs (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `MTLAlgorithm`.
- batch (`Any`): the validation data batch.
- dataloader_idx (`int`): the task ID of the validation dataloader. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
167 def on_validation_epoch_end( 168 self, 169 trainer: Trainer, 170 pl_module: MTLAlgorithm, 171 ) -> None: 172 r"""Log validation metrics to plot learning curves.""" 173 174 # compute average validation loss over tasks for logging learning curves 175 average_val_loss = MeanMetric().to( 176 device=next(iter(self.loss_cls_val.values())).device 177 ) 178 for metric in self.loss_cls_val.values(): 179 average_val_loss(metric.compute()) 180 181 pl_module.log( 182 "learning_curve/val/loss_cls", 183 average_val_loss.compute(), 184 on_epoch=True, 185 prog_bar=True, 186 )
Log validation metrics to plot learning curves.
188 def on_test_start( 189 self, 190 trainer: Trainer, 191 pl_module: MTLAlgorithm, 192 ) -> None: 193 r"""Initialize the metrics for testing each seen task in the beginning of a task's testing.""" 194 195 # initialize test metrics for current and previous tasks 196 self.loss_cls_test = { 197 task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks 198 }
Initialize the metrics for testing each seen task in the beginning of a task's testing.
200 def on_test_batch_end( 201 self, 202 trainer: Trainer, 203 pl_module: MTLAlgorithm, 204 outputs: dict[str, Any], 205 batch: Any, 206 batch_idx: int, 207 dataloader_idx: int = 0, 208 ) -> None: 209 r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches. 210 211 **Args:** 212 - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `MTLAlgorithm`. 213 - **batch** (`Any`): the validation data batch. 214 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 215 """ 216 217 # get the batch size 218 batch_size = len(batch) 219 220 test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx) 221 222 # get the metrics values of the batch from the outputs 223 loss_cls_batch = outputs["loss_cls"] 224 225 # update the accumulated metrics in order to calculate the metrics of the epoch 226 self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size)
Accumulate metrics from the test batch. We don't need to log and monitor the metrics of test batches.

Args:
- outputs (dict[str, Any]): the outputs of the test step, which is the return of the test_step() method in the MTLAlgorithm.
- batch (Any): the test data batch.
- dataloader_idx (int): the task ID of seen tasks to be tested. A default value of 0 is given, otherwise the LightningModule will raise a RuntimeError.
228 def on_test_epoch_end( 229 self, 230 trainer: Trainer, 231 pl_module: MTLAlgorithm, 232 ) -> None: 233 r"""Save and plot test metrics at the end of test.""" 234 235 # save (update) the test metrics to CSV files 236 self.save_test_loss_cls_to_csv( 237 csv_path=self.test_loss_cls_csv_path, 238 ) 239 240 # plot the test metrics 241 if hasattr(self, "test_loss_cls_plot_path"): 242 self.plot_test_loss_cls_from_csv( 243 csv_path=self.test_loss_cls_csv_path, 244 plot_path=self.test_loss_cls_plot_path, 245 )
Save and plot test metrics at the end of test.
247 def save_test_loss_cls_to_csv( 248 self, 249 csv_path: str, 250 ) -> None: 251 r"""Save the test classification loss metrics of all tasks in multi-task learning to an CSV file. 252 253 **Args:** 254 - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'. 255 """ 256 all_task_ids = list(self.loss_cls_test.keys()) 257 fieldnames = ["average_classification_loss"] + [ 258 f"test_on_task_{task_id}" for task_id in all_task_ids 259 ] 260 new_line = {} 261 262 # construct the columns and calculate the average loss over tasks at the same time 263 average_loss_over_tasks = MeanMetric().to( 264 device=next(iter(self.loss_cls_test.values())).device 265 ) 266 for task_id in all_task_ids: 267 loss = self.loss_cls_test[task_id].compute().item() 268 new_line[f"test_on_task_{task_id}"] = loss 269 average_loss_over_tasks(loss) 270 new_line["average_classification_loss"] = ( 271 average_loss_over_tasks.compute().item() 272 ) 273 274 # write 275 with open(csv_path, "w", encoding="utf-8") as file: 276 writer = csv.DictWriter(file, fieldnames=fieldnames) 277 writer.writeheader() 278 writer.writerow(new_line)
Save the test classification loss metrics of all tasks in multi-task learning to a CSV file.

Args:
- csv_path (str): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
280 def plot_test_loss_cls_from_csv(self, csv_path: str, plot_path: str) -> None: 281 """Plot the test classification loss bar chart of all tasks in multi-task learning from saved CSV file and save the plot to the designated directory. 282 283 **Args:** 284 - **csv_path** (`str`): the path to the csv file where the `utils.save_test_acc_csv()` saved the test classification loss metric. 285 - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls.png'. 286 """ 287 data = pd.read_csv(csv_path) 288 289 # extract all accuracy columns including average 290 all_columns = data.columns.tolist() 291 task_ids = list(range(len(all_columns))) # assign index-based positions 292 labels = [ 293 ( 294 col.replace("test_on_task_", "Task ") 295 if "test_on_task_" in col 296 else "Average" 297 ) 298 for col in all_columns 299 ] 300 loss_cls = data.iloc[0][all_columns].values 301 302 # plot the classification loss bar chart over tasks 303 fig, ax = plt.subplots(figsize=(16, 9)) 304 ax.bar( 305 task_ids, 306 loss_cls, 307 color="skyblue", 308 edgecolor="black", 309 ) 310 ax.set_xlabel("Task", fontsize=16) 311 ax.set_ylabel("Classification Loss", fontsize=16) 312 ax.grid(True) 313 ax.set_xticks(task_ids) 314 ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14) 315 ax.set_yticks([i * 0.05 for i in range(21)]) 316 ax.set_yticklabels( 317 [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14 318 ) 319 fig.tight_layout() 320 fig.savefig(plot_path) 321 plt.close(fig)
Plot the test classification loss bar chart of all tasks in multi-task learning from the saved CSV file and save the plot to the designated directory.

Args:
- csv_path (str): the path to the CSV file where save_test_loss_cls_to_csv() saved the test classification loss metric.
- plot_path (str): the path to save the plot; preferably the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls.png'.