clarena.metrics.mtl_loss

The submodule in metrics for MTLLoss.

  1r"""
  2The submodule in `metrics` for `MTLLoss`.
  3"""
  4
  5__all__ = ["MTLLoss"]
  6
  7import csv
  8import logging
  9import os
 10from typing import Any
 11
 12import pandas as pd
 13from lightning import Trainer
 14from matplotlib import pyplot as plt
 15from torchmetrics import MeanMetric
 16
 17from clarena.metrics import MetricCallback
 18from clarena.mtl_algorithms import MTLAlgorithm
 19from clarena.utils.metrics import MeanMetricBatch
 20
 21# always get logger for built-in logging in each module
 22pylogger = logging.getLogger(__name__)
 23
 24
 25class MTLLoss(MetricCallback):
 26    r"""Provides all actions that are related to MTL loss metrics, which include:
 27
 28    - Defining, initializing and recording loss metrics.
 29    - Logging training and validation loss metrics to Lightning loggers in real time.
 30    - Saving test loss metrics to files.
 31    - Visualizing test loss metrics as plots.
 32
 33    The callback can produce the following outputs:
 34
 35    - CSV files for test classification loss of all tasks and average classification loss.
 36    - Bar charts for test classification loss of all tasks.
 37    """
 38
 39    def __init__(
 40        self,
 41        save_dir: str,
 42        test_loss_cls_csv_name: str = "loss_cls.csv",
 43        test_loss_cls_plot_name: str | None = None,
 44    ) -> None:
 45        r"""
 46        **Args:**
  47        - **save_dir** (`str`): the directory where data and figures of metrics will be saved. Preferably inside the output folder.
  48        - **test_loss_cls_csv_name** (`str`): the file name for saving the classification loss of all tasks and the average classification loss as a CSV file.
  49        - **test_loss_cls_plot_name** (`str` | `None`): the file name for saving the classification loss plot. If `None`, no plot file will be saved.
 50        """
 51        super().__init__(save_dir=save_dir)
 52
 53        # paths
 54        self.test_loss_cls_csv_path: str = os.path.join(
 55            save_dir, test_loss_cls_csv_name
 56        )
 57        r"""The path to save test classification loss of all tasks and average classification loss CSV file."""
 58        if test_loss_cls_plot_name:
 59            self.test_loss_cls_plot_path: str = os.path.join(
 60                save_dir, test_loss_cls_plot_name
 61            )
 62            r"""The path to save test classification loss plot."""
 63
 64        # training accumulated metrics
 65        self.loss_cls_training_epoch: MeanMetricBatch
 66        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. """
 67
 68        # validation accumulated metrics
 69        self.loss_cls_val: dict[int, MeanMetricBatch] = {}
 70        r"""Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics. """
 71
 72        # test accumulated metrics
 73        self.loss_cls_test: dict[int, MeanMetricBatch] = {}
 74        r"""Test classification loss of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics. """
 75
 76    def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None:
 77        r"""Initialize training and validation metrics."""
 78
 79        # initialize training metrics
 80        self.loss_cls_training_epoch = MeanMetricBatch()
 81
 82        # initialize validation metrics
 83        self.loss_cls_val = {
 84            task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks
 85        }
 86
 87    def on_train_batch_end(
 88        self,
 89        trainer: Trainer,
 90        pl_module: MTLAlgorithm,
 91        outputs: dict[str, Any],
 92        batch: Any,
 93        batch_idx: int,
 94    ) -> None:
 95        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
 96
 97        **Args:**
 98        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `MTLAlgorithm`.
 99        - **batch** (`Any`): the training data batch.
100        """
101        # get the batch size
102        batch_size = len(batch)
103
104        # get training metrics values of current training batch from the outputs of the `training_step()`
105        loss_cls_batch = outputs["loss_cls"]
106
107        # update accumulated training metrics to calculate training metrics of the epoch
108        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
109
110        # log training metrics of current training batch to Lightning loggers
111        pl_module.log("train/loss_cls_batch", loss_cls_batch, prog_bar=True)
112
113        # log accumulated training metrics till this training batch to Lightning loggers
114        pl_module.log(
115            "task/train/loss_cls",
116            self.loss_cls_training_epoch.compute(),
117            prog_bar=True,
118        )
119
120    def on_train_epoch_end(
121        self,
122        trainer: Trainer,
123        pl_module: MTLAlgorithm,
124    ) -> None:
125        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
126
127        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
128        pl_module.log(
129            "learning_curve/train/loss_cls",
130            self.loss_cls_training_epoch.compute(),
131            on_epoch=True,
132            prog_bar=True,
133        )
134
135        # reset the metrics of the training epoch, as training spans multiple epochs (unlike validation and test, which run only one)
136        self.loss_cls_training_epoch.reset()
137
138    def on_validation_batch_end(
139        self,
140        trainer: Trainer,
141        pl_module: MTLAlgorithm,
142        outputs: dict[str, Any],
143        batch: Any,
144        batch_idx: int,
145        dataloader_idx: int = 0,
146    ) -> None:
147        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
148
149        **Args:**
150        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `MTLAlgorithm`.
151        - **batch** (`Any`): the validation data batch.
152        - **dataloader_idx** (`int`): the task ID of the validation dataloader. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
153        """
154        # get the batch size
155        batch_size = len(batch)
156
157        # map dataloader index to task id
158        val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx)
159
160        # get the metrics values of the batch from the outputs
161        loss_cls_batch = outputs["loss_cls"]
162
163        # update the accumulated metrics in order to calculate the validation metrics
164        self.loss_cls_val[val_task_id].update(loss_cls_batch, batch_size)
165
166    def on_validation_epoch_end(
167        self,
168        trainer: Trainer,
169        pl_module: MTLAlgorithm,
170    ) -> None:
171        r"""Log validation metrics to plot learning curves."""
172
173        # compute average validation loss over tasks for logging learning curves
174        average_val_loss = MeanMetric().to(
175            device=next(iter(self.loss_cls_val.values())).device
176        )
177        for metric in self.loss_cls_val.values():
178            average_val_loss(metric.compute())
179
180        pl_module.log(
181            "learning_curve/val/loss_cls",
182            average_val_loss.compute(),
183            on_epoch=True,
184            prog_bar=True,
185        )
186
187    def on_test_start(
188        self,
189        trainer: Trainer,
190        pl_module: MTLAlgorithm,
191    ) -> None:
192        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
193
194        # initialize test metrics for current and previous tasks
195        self.loss_cls_test = {
196            task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks
197        }
198
199    def on_test_batch_end(
200        self,
201        trainer: Trainer,
202        pl_module: MTLAlgorithm,
203        outputs: dict[str, Any],
204        batch: Any,
205        batch_idx: int,
206        dataloader_idx: int = 0,
207    ) -> None:
208        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
209
210        **Args:**
211        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `MTLAlgorithm`.
212        - **batch** (`Any`): the validation data batch.
213        - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`.
214        """
215
216        # get the batch size
217        batch_size = len(batch)
218
219        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
220
221        # get the metrics values of the batch from the outputs
222        loss_cls_batch = outputs["loss_cls"]
223
224        # update the accumulated metrics in order to calculate the metrics of the epoch
225        self.loss_cls_test[test_task_id].update(loss_cls_batch, batch_size)
226
227    def on_test_epoch_end(
228        self,
229        trainer: Trainer,
230        pl_module: MTLAlgorithm,
231    ) -> None:
232        r"""Save and plot test metrics at the end of test."""
233
234        # save (update) the test metrics to CSV files
235        self.save_test_loss_cls_to_csv(
236            csv_path=self.test_loss_cls_csv_path,
237        )
238
239        # plot the test metrics
240        if hasattr(self, "test_loss_cls_plot_path"):
241            self.plot_test_loss_cls_from_csv(
242                csv_path=self.test_loss_cls_csv_path,
243                plot_path=self.test_loss_cls_plot_path,
244            )
245
246    def save_test_loss_cls_to_csv(
247        self,
248        csv_path: str,
249    ) -> None:
250        r"""Save the test classification loss metrics of all tasks in multi-task learning to an CSV file.
251
252        **Args:**
253        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
254        """
255        all_task_ids = list(self.loss_cls_test.keys())
256        fieldnames = ["average_classification_loss"] + [
257            f"test_on_task_{task_id}" for task_id in all_task_ids
258        ]
259        new_line = {}
260
261        # construct the columns and calculate the average loss over tasks at the same time
262        average_loss_over_tasks = MeanMetric().to(
263            device=next(iter(self.loss_cls_test.values())).device
264        )
265        for task_id in all_task_ids:
266            loss = self.loss_cls_test[task_id].compute().item()
267            new_line[f"test_on_task_{task_id}"] = loss
268            average_loss_over_tasks(loss)
269        new_line["average_classification_loss"] = (
270            average_loss_over_tasks.compute().item()
271        )
272
273        # write
274        with open(csv_path, "w", encoding="utf-8", newline="") as file:  # newline="" avoids extra blank lines from the csv module
275            writer = csv.DictWriter(file, fieldnames=fieldnames)
276            writer.writeheader()
277            writer.writerow(new_line)
278
279    def plot_test_loss_cls_from_csv(self, csv_path: str, plot_path: str) -> None:
280        """Plot the test classification loss bar chart of all tasks in multi-task learning from saved CSV file and save the plot to the designated directory.
281
282        **Args:**
283        - **csv_path** (`str`): the path to the csv file where the `utils.save_test_acc_csv()` saved the test classification loss metric.
284        - **plot_path** (`str`): the path to save plot. Better same as the output directory of the experiment. E.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls.png'.
285        """
286        data = pd.read_csv(csv_path)
287
288        # extract all loss columns including the average
289        all_columns = data.columns.tolist()
290        task_ids = list(range(len(all_columns)))  # assign index-based positions
291        labels = [
292            (
293                col.replace("test_on_task_", "Task ")
294                if "test_on_task_" in col
295                else "Average"
296            )
297            for col in all_columns
298        ]
299        loss_cls = data.iloc[0][all_columns].values
300
301        # plot the classification loss bar chart over tasks
302        fig, ax = plt.subplots(figsize=(16, 9))
303        ax.bar(
304            task_ids,
305            loss_cls,
306            color="skyblue",
307            edgecolor="black",
308        )
309        ax.set_xlabel("Task", fontsize=16)
310        ax.set_ylabel("Classification Loss", fontsize=16)
311        ax.grid(True)
312        ax.set_xticks(task_ids)
313        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14)
314        ax.set_yticks([i * 0.05 for i in range(21)])
315        ax.set_yticklabels(
316            [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14
317        )
318        fig.tight_layout()
319        fig.savefig(plot_path)
320        plt.close(fig)
class MTLLoss(clarena.metrics.base.MetricCallback):

Provides all actions related to MTL loss metrics, including:

  • Defining, initializing and recording loss metrics.
  • Logging training and validation loss metrics to Lightning loggers in real time.
  • Saving test loss metrics to files.
  • Visualizing test loss metrics as plots.

The callback can produce the following outputs:

  • CSV files for test classification loss of all tasks and average classification loss.
  • Bar charts for test classification loss of all tasks.
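
A minimal usage sketch: construct the callback and attach it to the Lightning Trainer together with an MTLAlgorithm and a multi-task datamodule. The mtl_algorithm and mtl_datamodule objects below are hypothetical stand-ins for whatever your experiment builds.

    from lightning import Trainer

    from clarena.metrics import MTLLoss

    # metric files are written under save_dir
    loss_metric = MTLLoss(
        save_dir="./outputs/expr_name/results",  # hypothetical output directory
        test_loss_cls_csv_name="loss_cls.csv",
        test_loss_cls_plot_name="loss_cls.png",  # None would skip the bar chart
    )

    trainer = Trainer(max_epochs=10, callbacks=[loss_metric])
    trainer.fit(model=mtl_algorithm, datamodule=mtl_datamodule)   # logs train/val losses
    trainer.test(model=mtl_algorithm, datamodule=mtl_datamodule)  # writes the CSV and plot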
MTLLoss(save_dir: str, test_loss_cls_csv_name: str = 'loss_cls.csv', test_loss_cls_plot_name: str | None = None)

Args:

  • save_dir (str): the directory where data and figures of metrics will be saved. Preferably inside the output folder.
  • test_loss_cls_csv_name (str): the file name for saving the classification loss of all tasks and the average classification loss as a CSV file.
  • test_loss_cls_plot_name (str | None): the file name for saving the classification loss plot. If None, no plot file will be saved.
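
Note that the two output paths are simply save_dir joined with the given file names; a quick sketch with a hypothetical directory:

    import os

    save_dir = "./outputs/expr_name/results"  # hypothetical

    os.path.join(save_dir, "loss_cls.csv")  # -> './outputs/expr_name/results/loss_cls.csv'
    # test_loss_cls_plot_path is only set when test_loss_cls_plot_name is not None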
test_loss_cls_csv_path: str

The path of the CSV file where the test classification loss of all tasks and the average classification loss are saved.

loss_cls_training_epoch: clarena.utils.metrics.MeanMetricBatch

Classification loss of training epoch. Accumulated and calculated from the training batches.

loss_cls_val: dict[int, clarena.utils.metrics.MeanMetricBatch]

Validation classification loss of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics.

loss_cls_test: dict[int, clarena.utils.metrics.MeanMetricBatch]

Test classification loss of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics.
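
All three attributes are MeanMetricBatch accumulators from clarena.utils.metrics. Judging from how the hooks below use them, the pattern is a batch-size-weighted running mean; a small sketch with made-up loss values:

    from clarena.utils.metrics import MeanMetricBatch

    metric = MeanMetricBatch()

    # each hook feeds in the batch loss value together with the batch size
    metric.update(0.73, 32)  # hypothetical batch loss and batch size
    metric.update(0.58, 32)

    metric.compute()  # mean over all updated batches, here 0.655
    metric.reset()    # start fresh, e.g. for the next training epoch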

def on_fit_start(self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:

Initialize training and validation metrics.

def on_train_batch_end(self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:

Record training metrics from the training batch, and log both the batch metrics and the accumulated epoch metrics to Lightning loggers.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, i.e. the return value of the training_step() method in the MTLAlgorithm.
  • batch (Any): the training data batch.
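
This hook reads the batch loss from outputs["loss_cls"], so the algorithm's training_step() must return a dict containing that key. A compatible sketch (the forward pass and criterion are hypothetical placeholders, not the actual MTLAlgorithm implementation):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)                      # hypothetical forward pass
        loss_cls = self.criterion(logits, y)  # hypothetical classification loss

        # "loss" drives optimization; "loss_cls" is what this callback records
        return {"loss": loss_cls, "loss_cls": loss_cls}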
def on_train_epoch_end(self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:

Log metrics of the training epoch for plotting learning curves, and reset the accumulated metrics at the end of the training epoch.

def on_validation_batch_end(self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:

Accumulate metrics from the validation batch. The metrics of validation batches don't need to be logged or monitored.

Args:

  • outputs (dict[str, Any]): the outputs of the validation step, i.e. the return value of the validation_step() method in the MTLAlgorithm.
  • batch (Any): the validation data batch.
  • dataloader_idx (int): the index of the validation dataloader, which is mapped to a task ID. A default value of 0 is required by Lightning; otherwise the LightningModule raises a RuntimeError.
def on_validation_epoch_end(self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:

Log validation metrics to plot learning curves.

def on_test_start(self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:

Initialize the test metrics for all evaluation tasks at the beginning of testing.

def on_test_batch_end(self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:

Accumulate metrics from the test batch. The metrics of test batches don't need to be logged or monitored.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, i.e. the return value of the test_step() method in the MTLAlgorithm.
  • batch (Any): the test data batch.
  • dataloader_idx (int): the index of the test dataloader, which is mapped to a task ID. A default value of 0 is required by Lightning; otherwise the LightningModule raises a RuntimeError.
def on_test_epoch_end(self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:

Save and plot test metrics at the end of testing.

def save_test_loss_cls_to_csv(self, csv_path: str) -> None:

Save the test classification loss metrics of all tasks in multi-task learning to a CSV file.

Args:

  • csv_path (str): the path to save the test metrics to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
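
The resulting file contains a single data row, with the average first and one column per evaluated task. A sketch with made-up values for three tasks:

    average_classification_loss,test_on_task_1,test_on_task_2,test_on_task_3
    0.4216,0.3871,0.4502,0.4275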
def plot_test_loss_cls_from_csv(self, csv_path: str, plot_path: str) -> None:

Plot the test classification loss bar chart of all tasks in multi-task learning from the saved CSV file, and save the plot to the designated path.

Args:

  • csv_path (str): the path to the CSV file where save_test_loss_cls_to_csv() saved the test classification loss metrics.
  • plot_path (str): the path to save the plot to, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/loss_cls.png'.
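
Since this method only reads from the CSV file, it can also be called outside the Lightning loop, e.g. to regenerate the bar chart from a finished run. A sketch with hypothetical paths:

    metric_callback = MTLLoss(save_dir="./outputs/expr_name/results")
    metric_callback.plot_test_loss_cls_from_csv(
        csv_path="./outputs/expr_name/results/loss_cls.csv",
        plot_path="./outputs/expr_name/results/loss_cls.png",
    )

Note that the y-axis tick labels are fixed at 0.05 intervals from 0.00 to 1.00; bars for losses above 1.0 are still drawn, but without tick labels beyond that range.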