clarena.metrics.mtl_acc

The submodule in metrics for MTLAccuracy.

  1r"""
  2The submodule in `metrics` for `MTLAccuracy`.
  3"""
  4
  5__all__ = ["MTLAccuracy"]
  6
  7import csv
  8import logging
  9import os
 10from typing import Any
 11
 12import pandas as pd
 13from lightning import Trainer
 14from matplotlib import pyplot as plt
 15from torchmetrics import MeanMetric
 16
 17from clarena.metrics import MetricCallback
 18from clarena.mtl_algorithms import MTLAlgorithm
 19from clarena.utils.metrics import MeanMetricBatch
 20
 21# always get logger for built-in logging in each module
 22pylogger = logging.getLogger(__name__)
 23
 24
 25class MTLAccuracy(MetricCallback):
 26    r"""Provides all actions that are related to MTL accuracy metrics, which include:
 27
 28    - Defining, initializing and recording accuracy metric.
 29    - Logging training and validation accuracy metric to Lightning loggers in real time.
 30    - Saving test accuracy metric to files.
 31    - Visualizing test accuracy metric as plots.
 32
 33    The callback is able to produce the following outputs:
 34
 35    - CSV files for test accuracy of all tasks and average accuracy.
 36    - Bar charts for test accuracy of all tasks.
 37    """
 38
 39    def __init__(
 40        self,
 41        save_dir: str,
 42        test_acc_csv_name: str = "acc.csv",
 43        test_acc_plot_name: str | None = None,
 44    ) -> None:
 45        r"""
 46        **Args:**
 47        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 48        - **test_acc_csv_name** (`str`): file name to save test accuracy of all tasks and average accuracy as CSV file.
 49        - **test_acc_plot_name** (`str` | `None`): file name to save accuracy plot. If `None`, no file will be saved.
 50        """
 51        super().__init__(save_dir=save_dir)
 52
 53        # paths
 54        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 55        r"""The path to save test accuracy of all tasks and average accuracy CSV file."""
 56        if test_acc_plot_name:
 57            self.test_acc_plot_path: str = os.path.join(save_dir, test_acc_plot_name)
 58            r"""The path to save test accuracy plot."""
 59
 60        # training accumulated metrics
 61        self.acc_training_epoch: MeanMetricBatch
 62        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. """
 63
 64        # validation accumulated metrics
 65        self.acc_val: dict[int, MeanMetricBatch] = {}
 66        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics."""
 67
 68        # test accumulated metrics
 69        self.acc_test: dict[int, MeanMetricBatch] = {}
 70        r"""Test classification accuracy of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics."""
 71
 72    def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None:
 73        r"""Initialize training and validation metrics."""
 74
 75        # initialize training metrics
 76        self.acc_training_epoch = MeanMetricBatch()
 77
 78        # initialize validation metrics
 79        self.acc_val = {
 80            task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks
 81        }
 82
 83    def on_train_batch_end(
 84        self,
 85        trainer: Trainer,
 86        pl_module: MTLAlgorithm,
 87        outputs: dict[str, Any],
 88        batch: Any,
 89        batch_idx: int,
 90    ) -> None:
 91        r"""Record training metrics from the training batch and log both the batch metrics and the accumulated epoch metrics to Lightning loggers.
 92
 93        **Args:**
 94        - **outputs** (`dict[str, Any]`): the outputs of the training step, i.e. the return value of the `training_step()` method in the `MTLAlgorithm`.
 95        - **batch** (`Any`): the training data batch.
 96        """
 97        # get the batch size
 98        batch_size = len(batch)
 99
100        # get training metrics values of current training batch from the outputs of the `training_step()`
101        acc_batch = outputs["acc"]
102
103        # update accumulated training metrics to calculate training metrics of the epoch
104        self.acc_training_epoch.update(acc_batch, batch_size)
105
106        # log training metrics of current training batch to Lightning loggers
107        pl_module.log("train/acc_batch", acc_batch, prog_bar=True)
108
109        # log accumulated training metrics till this training batch to Lightning loggers
110        pl_module.log(
111            "task/train/acc",
112            self.acc_training_epoch.compute(),
113            prog_bar=True,
114        )
115
116    def on_train_epoch_end(
117        self,
118        trainer: Trainer,
119        pl_module: MTLAlgorithm,
120    ) -> None:
121        r"""Reset the accumulated training metrics at the end of the training epoch."""
122
123        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
124        self.acc_training_epoch.reset()
125
126    def on_validation_batch_end(
127        self,
128        trainer: Trainer,
129        pl_module: MTLAlgorithm,
130        outputs: dict[str, Any],
131        batch: Any,
132        batch_idx: int,
133        dataloader_idx: int = 0,
134    ) -> None:
135        r"""Accumulate metrics from the validation batch. There is no need to log or monitor the metrics of individual validation batches.
136
137        **Args:**
138        - **outputs** (`dict[str, Any]`): the outputs of the validation step, i.e. the return value of the `validation_step()` method in the `MTLAlgorithm`.
139        - **batch** (`Any`): the validation data batch.
140        - **dataloader_idx** (`int`): the task ID of the validation dataloader. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.
141        """
142
143        # get the batch size
144        batch_size = len(batch)
145
146        val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx)
147
148        # get the metrics values of the batch from the outputs
149        acc_batch = outputs["acc"]
150
151        # update the accumulated metrics in order to calculate the validation metrics
152        self.acc_val[val_task_id].update(acc_batch, batch_size)
153
154    def on_test_start(
155        self,
156        trainer: Trainer,
157        pl_module: MTLAlgorithm,
158    ) -> None:
159        r"""Initialize the test metrics for all evaluation tasks at the beginning of testing."""
160
161        # initialize test metrics for all evaluation tasks
162        self.acc_test = {
163            task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks
164        }
165
166    def on_test_batch_end(
167        self,
168        trainer: Trainer,
169        pl_module: MTLAlgorithm,
170        outputs: dict[str, Any],
171        batch: Any,
172        batch_idx: int,
173        dataloader_idx: int = 0,
174    ) -> None:
175        r"""Accumulate metrics from the test batch. There is no need to log or monitor the metrics of individual test batches.
176
177        **Args:**
178        - **outputs** (`dict[str, Any]`): the outputs of the test step, i.e. the return value of the `test_step()` method in the `MTLAlgorithm`.
179        - **batch** (`Any`): the test data batch.
180        - **dataloader_idx** (`int`): the task ID of the task to be tested. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.
181        """
182
183        # get the batch size
184        batch_size = len(batch)
185
186        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
187
188        # get the metrics values of the batch from the outputs
189        acc_batch = outputs["acc"]
190
191        # update the accumulated metrics in order to calculate the metrics of the epoch
192        self.acc_test[test_task_id].update(acc_batch, batch_size)
193
194    def on_test_epoch_end(
195        self,
196        trainer: Trainer,
197        pl_module: MTLAlgorithm,
198    ) -> None:
199        r"""Save and plot test metrics at the end of test."""
200
201        # save (update) the test metrics to CSV files
202        self.save_test_acc_to_csv(
203            csv_path=self.test_acc_csv_path,
204        )
205
206        # plot the test metrics
207        if hasattr(self, "test_acc_plot_path"):
208            self.plot_test_acc_from_csv(
209                csv_path=self.test_acc_csv_path,
210                plot_path=self.test_acc_plot_path,
211            )
212
213    def save_test_acc_to_csv(
214        self,
215        csv_path: str,
216    ) -> None:
217        r"""Save the test accuracy metrics of all tasks in multi-task learning to a CSV file.
218
219        **Args:**
220        - **csv_path** (`str`): the path to save the test metrics to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
221        """
222        all_task_ids = list(self.acc_test.keys())
223        fieldnames = ["average_accuracy"] + [
224            f"test_on_task_{task_id}" for task_id in all_task_ids
225        ]
226        new_line = {}
227
228        # construct the columns and calculate the average accuracy over tasks at the same time
229        average_accuracy_over_tasks = MeanMetric().to(
230            device=next(iter(self.acc_test.values())).device
231        )
232        for task_id in all_task_ids:
233            acc = self.acc_test[task_id].compute().item()
234            new_line[f"test_on_task_{task_id}"] = acc
235            average_accuracy_over_tasks(acc)
236        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()
237
238        # write
239        with open(csv_path, "w", encoding="utf-8") as file:
240            writer = csv.DictWriter(file, fieldnames=fieldnames)
241            writer.writeheader()
242            writer.writerow(new_line)
243
244    def plot_test_acc_from_csv(self, csv_path: str, plot_path: str) -> None:
245        """Plot the test accuracy bar chart of all tasks in multi-task learning from the saved CSV file and save the plot to the designated directory.
246
247        **Args:**
248        - **csv_path** (`str`): the path to the CSV file where `save_test_acc_to_csv()` saved the test accuracy metrics.
249        - **plot_path** (`str`): the path to save the plot to, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc.png'.
250        """
251        data = pd.read_csv(csv_path)
252
253        # extract all accuracy columns including average
254        all_columns = data.columns.tolist()
255        task_ids = list(range(len(all_columns)))  # assign index-based positions
256        labels = [
257            (
258                col.replace("test_on_task_", "Task ")
259                if "test_on_task_" in col
260                else "Average"
261            )
262            for col in all_columns
263        ]
264        accuracies = data.iloc[0][all_columns].values
265
266        # plot the accuracy bar chart over tasks
267        fig, ax = plt.subplots(figsize=(16, 9))
268        ax.bar(
269            task_ids,
270            accuracies,
271            color="skyblue",
272            edgecolor="black",
273        )
274        ax.set_xlabel("Task", fontsize=16)
275        ax.set_ylabel("Accuracy", fontsize=16)
276        ax.grid(True)
277        ax.set_xticks(task_ids)
278        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14)
279        ax.set_yticks([i * 0.05 for i in range(21)])
280        ax.set_yticklabels(
281            [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14
282        )
283        fig.tight_layout()
284        fig.savefig(plot_path)
285        plt.close(fig)
class MTLAccuracy(clarena.metrics.base.MetricCallback):
 26class MTLAccuracy(MetricCallback):
 27    r"""Provides all actions that are related to MTL accuracy metrics, which include:
 28
 29    - Defining, initializing and recording accuracy metric.
 30    - Logging training and validation accuracy metric to Lightning loggers in real time.
 31    - Saving test accuracy metric to files.
 32    - Visualizing test accuracy metric as plots.
 33
 34    The callback is able to produce the following outputs:
 35
 36    - CSV files for test accuracy of all tasks and average accuracy.
 37    - Bar charts for test accuracy of all tasks.
 38    """
 39
 40    def __init__(
 41        self,
 42        save_dir: str,
 43        test_acc_csv_name: str = "acc.csv",
 44        test_acc_plot_name: str | None = None,
 45    ) -> None:
 46        r"""
 47        **Args:**
 48        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 49        - **test_acc_csv_name** (`str`): file name to save test accuracy of all tasks and average accuracy as CSV file.
 50        - **test_acc_plot_name** (`str` | `None`): file name to save accuracy plot. If `None`, no file will be saved.
 51        """
 52        super().__init__(save_dir=save_dir)
 53
 54        # paths
 55        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 56        r"""The path to save test accuracy of all tasks and average accuracy CSV file."""
 57        if test_acc_plot_name:
 58            self.test_acc_plot_path: str = os.path.join(save_dir, test_acc_plot_name)
 59            r"""The path to save test accuracy plot."""
 60
 61        # training accumulated metrics
 62        self.acc_training_epoch: MeanMetricBatch
 63        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. """
 64
 65        # validation accumulated metrics
 66        self.acc_val: dict[int, MeanMetricBatch] = {}
 67        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics."""
 68
 69        # test accumulated metrics
 70        self.acc_test: dict[int, MeanMetricBatch] = {}
 71        r"""Test classification accuracy of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics."""
 72
 73    def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None:
 74        r"""Initialize training and validation metrics."""
 75
 76        # initialize training metrics
 77        self.acc_training_epoch = MeanMetricBatch()
 78
 79        # initialize validation metrics
 80        self.acc_val = {
 81            task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks
 82        }
 83
 84    def on_train_batch_end(
 85        self,
 86        trainer: Trainer,
 87        pl_module: MTLAlgorithm,
 88        outputs: dict[str, Any],
 89        batch: Any,
 90        batch_idx: int,
 91    ) -> None:
 92        r"""Record training metrics from the training batch and log both the batch metrics and the accumulated epoch metrics to Lightning loggers.
 93
 94        **Args:**
 95        - **outputs** (`dict[str, Any]`): the outputs of the training step, i.e. the return value of the `training_step()` method in the `MTLAlgorithm`.
 96        - **batch** (`Any`): the training data batch.
 97        """
 98        # get the batch size
 99        batch_size = len(batch)
100
101        # get training metrics values of current training batch from the outputs of the `training_step()`
102        acc_batch = outputs["acc"]
103
104        # update accumulated training metrics to calculate training metrics of the epoch
105        self.acc_training_epoch.update(acc_batch, batch_size)
106
107        # log training metrics of current training batch to Lightning loggers
108        pl_module.log("train/acc_batch", acc_batch, prog_bar=True)
109
110        # log accumulated training metrics till this training batch to Lightning loggers
111        pl_module.log(
112            "task/train/acc",
113            self.acc_training_epoch.compute(),
114            prog_bar=True,
115        )
116
117    def on_train_epoch_end(
118        self,
119        trainer: Trainer,
120        pl_module: MTLAlgorithm,
121    ) -> None:
122        r"""Reset the accumulated training metrics at the end of the training epoch."""
123
124        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
125        self.acc_training_epoch.reset()
126
127    def on_validation_batch_end(
128        self,
129        trainer: Trainer,
130        pl_module: MTLAlgorithm,
131        outputs: dict[str, Any],
132        batch: Any,
133        batch_idx: int,
134        dataloader_idx: int = 0,
135    ) -> None:
136        r"""Accumulate metrics from the validation batch. There is no need to log or monitor the metrics of individual validation batches.
137
138        **Args:**
139        - **outputs** (`dict[str, Any]`): the outputs of the validation step, i.e. the return value of the `validation_step()` method in the `MTLAlgorithm`.
140        - **batch** (`Any`): the validation data batch.
141        - **dataloader_idx** (`int`): the task ID of the validation dataloader. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.
142        """
143
144        # get the batch size
145        batch_size = len(batch)
146
147        val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx)
148
149        # get the metrics values of the batch from the outputs
150        acc_batch = outputs["acc"]
151
152        # update the accumulated metrics in order to calculate the validation metrics
153        self.acc_val[val_task_id].update(acc_batch, batch_size)
154
155    def on_test_start(
156        self,
157        trainer: Trainer,
158        pl_module: MTLAlgorithm,
159    ) -> None:
160        r"""Initialize the test metrics for all evaluation tasks at the beginning of testing."""
161
162        # initialize test metrics for all evaluation tasks
163        self.acc_test = {
164            task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks
165        }
166
167    def on_test_batch_end(
168        self,
169        trainer: Trainer,
170        pl_module: MTLAlgorithm,
171        outputs: dict[str, Any],
172        batch: Any,
173        batch_idx: int,
174        dataloader_idx: int = 0,
175    ) -> None:
176        r"""Accumulate metrics from the test batch. There is no need to log or monitor the metrics of individual test batches.
177
178        **Args:**
179        - **outputs** (`dict[str, Any]`): the outputs of the test step, i.e. the return value of the `test_step()` method in the `MTLAlgorithm`.
180        - **batch** (`Any`): the test data batch.
181        - **dataloader_idx** (`int`): the task ID of the task to be tested. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.
182        """
183
184        # get the batch size
185        batch_size = len(batch)
186
187        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
188
189        # get the metrics values of the batch from the outputs
190        acc_batch = outputs["acc"]
191
192        # update the accumulated metrics in order to calculate the metrics of the epoch
193        self.acc_test[test_task_id].update(acc_batch, batch_size)
194
195    def on_test_epoch_end(
196        self,
197        trainer: Trainer,
198        pl_module: MTLAlgorithm,
199    ) -> None:
200        r"""Save and plot test metrics at the end of test."""
201
202        # save (update) the test metrics to CSV files
203        self.save_test_acc_to_csv(
204            csv_path=self.test_acc_csv_path,
205        )
206
207        # plot the test metrics
208        if hasattr(self, "test_acc_plot_path"):
209            self.plot_test_acc_from_csv(
210                csv_path=self.test_acc_csv_path,
211                plot_path=self.test_acc_plot_path,
212            )
213
214    def save_test_acc_to_csv(
215        self,
216        csv_path: str,
217    ) -> None:
218        r"""Save the test accuracy metrics of all tasks in multi-task learning to a CSV file.
219
220        **Args:**
221        - **csv_path** (`str`): the path to save the test metrics to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
222        """
223        all_task_ids = list(self.acc_test.keys())
224        fieldnames = ["average_accuracy"] + [
225            f"test_on_task_{task_id}" for task_id in all_task_ids
226        ]
227        new_line = {}
228
229        # construct the columns and calculate the average accuracy over tasks at the same time
230        average_accuracy_over_tasks = MeanMetric().to(
231            device=next(iter(self.acc_test.values())).device
232        )
233        for task_id in all_task_ids:
234            acc = self.acc_test[task_id].compute().item()
235            new_line[f"test_on_task_{task_id}"] = acc
236            average_accuracy_over_tasks(acc)
237        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()
238
239        # write
240        with open(csv_path, "w", encoding="utf-8") as file:
241            writer = csv.DictWriter(file, fieldnames=fieldnames)
242            writer.writeheader()
243            writer.writerow(new_line)
244
245    def plot_test_acc_from_csv(self, csv_path: str, plot_path: str) -> None:
246        """Plot the test accuracy bar chart of all tasks in multi-task learning from the saved CSV file and save the plot to the designated directory.
247
248        **Args:**
249        - **csv_path** (`str`): the path to the CSV file where `save_test_acc_to_csv()` saved the test accuracy metrics.
250        - **plot_path** (`str`): the path to save the plot to, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc.png'.
251        """
252        data = pd.read_csv(csv_path)
253
254        # extract all accuracy columns including average
255        all_columns = data.columns.tolist()
256        task_ids = list(range(len(all_columns)))  # assign index-based positions
257        labels = [
258            (
259                col.replace("test_on_task_", "Task ")
260                if "test_on_task_" in col
261                else "Average"
262            )
263            for col in all_columns
264        ]
265        accuracies = data.iloc[0][all_columns].values
266
267        # plot the accuracy bar chart over tasks
268        fig, ax = plt.subplots(figsize=(16, 9))
269        ax.bar(
270            task_ids,
271            accuracies,
272            color="skyblue",
273            edgecolor="black",
274        )
275        ax.set_xlabel("Task", fontsize=16)
276        ax.set_ylabel("Accuracy", fontsize=16)
277        ax.grid(True)
278        ax.set_xticks(task_ids)
279        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14)
280        ax.set_yticks([i * 0.05 for i in range(21)])
281        ax.set_yticklabels(
282            [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14
283        )
284        fig.tight_layout()
285        fig.savefig(plot_path)
286        plt.close(fig)

Provides all actions that are related to MTL accuracy metrics, which include:

  • Defining, initializing and recording accuracy metric.
  • Logging training and validation accuracy metric to Lightning loggers in real time.
  • Saving test accuracy metric to files.
  • Visualizing test accuracy metric as plots.

The callback is able to produce the following outputs:

  • CSV files for test accuracy of all tasks and average accuracy.
  • Bar charts for test accuracy of all tasks.
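
A minimal usage sketch is shown below. The `model`, `datamodule`, and output paths are hypothetical placeholders (an `MTLAlgorithm` and a multi-task datamodule are assumed), and the import path assumes the class is re-exported from `clarena.metrics`; CLArena experiments may wire the callback up differently, e.g. via configuration files.

    from lightning import Trainer

    from clarena.metrics import MTLAccuracy

    # collect accuracy during fit, then save the CSV and bar chart after test
    acc_metric = MTLAccuracy(
        save_dir="./outputs/expr_name/1970-01-01_00-00-00/results",
        test_acc_csv_name="acc.csv",
        test_acc_plot_name="acc.png",  # pass None to skip the plot
    )

    trainer = Trainer(max_epochs=10, callbacks=[acc_metric])
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)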
MTLAccuracy( save_dir: str, test_acc_csv_name: str = 'acc.csv', test_acc_plot_name: str | None = None)
40    def __init__(
41        self,
42        save_dir: str,
43        test_acc_csv_name: str = "acc.csv",
44        test_acc_plot_name: str | None = None,
45    ) -> None:
46        r"""
47        **Args:**
48        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
49        - **test_acc_csv_name** (`str`): file name to save test accuracy of all tasks and average accuracy as CSV file.
50        - **test_acc_plot_name** (`str` | `None`): file name to save accuracy plot. If `None`, no file will be saved.
51        """
52        super().__init__(save_dir=save_dir)
53
54        # paths
55        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
56        r"""The path to save test accuracy of all tasks and average accuracy CSV file."""
57        if test_acc_plot_name:
58            self.test_acc_plot_path: str = os.path.join(save_dir, test_acc_plot_name)
59            r"""The path to save test accuracy plot."""
60
61        # training accumulated metrics
62        self.acc_training_epoch: MeanMetricBatch
63        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. """
64
65        # validation accumulated metrics
66        self.acc_val: dict[int, MeanMetricBatch] = {}
67        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics."""
68
69        # test accumulated metrics
70        self.acc_test: dict[int, MeanMetricBatch] = {}
71        r"""Test classification accuracy of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics."""

Args:

  • save_dir (str): The directory where data and figures of metrics will be saved. Better inside the output folder.
  • test_acc_csv_name (str): file name to save test accuracy of all tasks and average accuracy as CSV file.
  • test_acc_plot_name (str | None): file name to save accuracy plot. If None, no file will be saved.
test_acc_csv_path: str

The path to save test accuracy of all tasks and average accuracy CSV file.

Classification accuracy of training epoch. Accumulated and calculated from the training batches.

Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. Keys are task IDs and values are the corresponding metrics.

Test classification accuracy of all tasks. Accumulated and calculated from the test batches. Keys are task IDs and values are the corresponding metrics.

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:
73    def on_fit_start(self, trainer: Trainer, pl_module: MTLAlgorithm) -> None:
74        r"""Initialize training and validation metrics."""
75
76        # initialize training metrics
77        self.acc_training_epoch = MeanMetricBatch()
78
79        # initialize validation metrics
80        self.acc_val = {
81            task_id: MeanMetricBatch() for task_id in trainer.datamodule.train_tasks
82        }

Initialize training and validation metrics.
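
`MeanMetricBatch` comes from `clarena.utils.metrics` and its implementation is not shown on this page. Judging from how this callback uses it (`update(value, batch_size)`, `compute()`, `reset()`), it behaves like a batch-size-weighted running mean. A rough, illustrative stand-in (not the actual CLArena implementation):

    import torch
    from torchmetrics import Metric

    class WeightedRunningMean(Metric):
        """Illustrative stand-in for MeanMetricBatch: a mean weighted by batch size."""

        def __init__(self) -> None:
            super().__init__()
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, value: torch.Tensor, batch_size: int) -> None:
            # weight each batch's value by the number of samples it contains
            self.total += value * batch_size
            self.count += batch_size

        def compute(self) -> torch.Tensor:
            return self.total / self.count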

def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
 84    def on_train_batch_end(
 85        self,
 86        trainer: Trainer,
 87        pl_module: MTLAlgorithm,
 88        outputs: dict[str, Any],
 89        batch: Any,
 90        batch_idx: int,
 91    ) -> None:
 92        r"""Record training metrics from the training batch and log both the batch metrics and the accumulated epoch metrics to Lightning loggers.
 93
 94        **Args:**
 95        - **outputs** (`dict[str, Any]`): the outputs of the training step, i.e. the return value of the `training_step()` method in the `MTLAlgorithm`.
 96        - **batch** (`Any`): the training data batch.
 97        """
 98        # get the batch size
 99        batch_size = len(batch)
100
101        # get training metrics values of current training batch from the outputs of the `training_step()`
102        acc_batch = outputs["acc"]
103
104        # update accumulated training metrics to calculate training metrics of the epoch
105        self.acc_training_epoch.update(acc_batch, batch_size)
106
107        # log training metrics of current training batch to Lightning loggers
108        pl_module.log("train/acc_batch", acc_batch, prog_bar=True)
109
110        # log accumulated training metrics till this training batch to Lightning loggers
111        pl_module.log(
112            "task/train/acc",
113            self.acc_training_epoch.compute(),
114            prog_bar=True,
115        )

Record training metrics from the training batch and log both the batch metrics and the accumulated epoch metrics to Lightning loggers.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, i.e. the return value of the training_step() method in the MTLAlgorithm.
  • batch (Any): the training data batch.
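
The logged `task/train/acc` value is therefore a running, batch-size-weighted mean of the per-batch accuracies seen so far in the epoch. A small illustration with made-up numbers:

    # (acc_batch, batch_size) for two batches of different sizes
    batches = [(0.90, 64), (0.80, 32)]

    weighted_sum = sum(acc * n for acc, n in batches)
    n_samples = sum(n for _, n in batches)
    print(weighted_sum / n_samples)  # ~0.867, not the unweighted mean 0.85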
def on_train_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:
117    def on_train_epoch_end(
118        self,
119        trainer: Trainer,
120        pl_module: MTLAlgorithm,
121    ) -> None:
122        r"""Reset the accumulated training metrics at the end of the training epoch."""
123
124        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
125        self.acc_training_epoch.reset()

Reset the accumulated training metrics at the end of the training epoch.

def on_validation_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
127    def on_validation_batch_end(
128        self,
129        trainer: Trainer,
130        pl_module: MTLAlgorithm,
131        outputs: dict[str, Any],
132        batch: Any,
133        batch_idx: int,
134        dataloader_idx: int = 0,
135    ) -> None:
136        r"""Accumulate metrics from the validation batch. There is no need to log or monitor the metrics of individual validation batches.
137
138        **Args:**
139        - **outputs** (`dict[str, Any]`): the outputs of the validation step, i.e. the return value of the `validation_step()` method in the `MTLAlgorithm`.
140        - **batch** (`Any`): the validation data batch.
141        - **dataloader_idx** (`int`): the task ID of the validation dataloader. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.
142        """
143
144        # get the batch size
145        batch_size = len(batch)
146
147        val_task_id = pl_module.get_val_task_id_from_dataloader_idx(dataloader_idx)
148
149        # get the metrics values of the batch from the outputs
150        acc_batch = outputs["acc"]
151
152        # update the accumulated metrics in order to calculate the validation metrics
153        self.acc_val[val_task_id].update(acc_batch, batch_size)

Accumulate metrics from the validation batch. There is no need to log or monitor the metrics of individual validation batches.

Args:

  • outputs (dict[str, Any]): the outputs of the validation step, i.e. the return value of the validation_step() method in the MTLAlgorithm.
  • batch (Any): the validation data batch.
  • dataloader_idx (int): the task ID of the validation dataloader. A default value of 0 is given, otherwise the LightningModule will raise a RuntimeError.
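
The mapping from `dataloader_idx` back to a task ID is delegated to `pl_module.get_val_task_id_from_dataloader_idx()`, whose implementation is not part of this page. A plausible sketch, assuming the validation dataloaders are built in the order of the datamodule's `train_tasks` (an assumption, not the confirmed CLArena behaviour):

    def get_val_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        # hypothetical: task IDs in the order the validation dataloaders were created
        val_task_ids = list(self.trainer.datamodule.train_tasks)
        return val_task_ids[dataloader_idx]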
def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:
155    def on_test_start(
156        self,
157        trainer: Trainer,
158        pl_module: MTLAlgorithm,
159    ) -> None:
160        r"""Initialize the test metrics for all evaluation tasks at the beginning of testing."""
161
162        # initialize test metrics for all evaluation tasks
163        self.acc_test = {
164            task_id: MeanMetricBatch() for task_id in trainer.datamodule.eval_tasks
165        }

Initialize the test metrics for all evaluation tasks at the beginning of testing.

def on_test_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
167    def on_test_batch_end(
168        self,
169        trainer: Trainer,
170        pl_module: MTLAlgorithm,
171        outputs: dict[str, Any],
172        batch: Any,
173        batch_idx: int,
174        dataloader_idx: int = 0,
175    ) -> None:
176        r"""Accumulate metrics from the test batch. There is no need to log or monitor the metrics of individual test batches.
177
178        **Args:**
179        - **outputs** (`dict[str, Any]`): the outputs of the test step, i.e. the return value of the `test_step()` method in the `MTLAlgorithm`.
180        - **batch** (`Any`): the test data batch.
181        - **dataloader_idx** (`int`): the task ID of the task to be tested. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.
182        """
183
184        # get the batch size
185        batch_size = len(batch)
186
187        test_task_id = pl_module.get_test_task_id_from_dataloader_idx(dataloader_idx)
188
189        # get the metrics values of the batch from the outputs
190        acc_batch = outputs["acc"]
191
192        # update the accumulated metrics in order to calculate the metrics of the epoch
193        self.acc_test[test_task_id].update(acc_batch, batch_size)

Accumulate metrics from the test batch. There is no need to log or monitor the metrics of individual test batches.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, i.e. the return value of the test_step() method in the MTLAlgorithm.
  • batch (Any): the test data batch.
  • dataloader_idx (int): the task ID of the task to be tested. A default value of 0 is given, otherwise the LightningModule will raise a RuntimeError.
def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.mtl_algorithms.MTLAlgorithm) -> None:
195    def on_test_epoch_end(
196        self,
197        trainer: Trainer,
198        pl_module: MTLAlgorithm,
199    ) -> None:
200        r"""Save and plot test metrics at the end of test."""
201
202        # save (update) the test metrics to CSV files
203        self.save_test_acc_to_csv(
204            csv_path=self.test_acc_csv_path,
205        )
206
207        # plot the test metrics
208        if hasattr(self, "test_acc_plot_path"):
209            self.plot_test_acc_from_csv(
210                csv_path=self.test_acc_csv_path,
211                plot_path=self.test_acc_plot_path,
212            )

Save and plot test metrics at the end of test.

def save_test_acc_to_csv(self, csv_path: str) -> None:
214    def save_test_acc_to_csv(
215        self,
216        csv_path: str,
217    ) -> None:
218        r"""Save the test accuracy metrics of all tasks in multi-task learning to a CSV file.
219
220        **Args:**
221        - **csv_path** (`str`): the path to save the test metrics to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
222        """
223        all_task_ids = list(self.acc_test.keys())
224        fieldnames = ["average_accuracy"] + [
225            f"test_on_task_{task_id}" for task_id in all_task_ids
226        ]
227        new_line = {}
228
229        # construct the columns and calculate the average accuracy over tasks at the same time
230        average_accuracy_over_tasks = MeanMetric().to(
231            device=next(iter(self.acc_test.values())).device
232        )
233        for task_id in all_task_ids:
234            acc = self.acc_test[task_id].compute().item()
235            new_line[f"test_on_task_{task_id}"] = acc
236            average_accuracy_over_tasks(acc)
237        new_line["average_accuracy"] = average_accuracy_over_tasks.compute().item()
238
239        # write
240        with open(csv_path, "w", encoding="utf-8") as file:
241            writer = csv.DictWriter(file, fieldnames=fieldnames)
242            writer.writeheader()
243            writer.writerow(new_line)

Save the test accuracy metrics of all tasks in multi-task learning to a CSV file.

Args:

  • csv_path (str): the path to save the test metrics to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
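
With the fieldnames built above, the saved file contains a single header row followed by one data row. An illustrative layout for three tasks (the accuracy values are made up):

    average_accuracy,test_on_task_1,test_on_task_2,test_on_task_3
    0.87,0.91,0.85,0.85

Because the file is opened in write mode, rerunning the test overwrites any previous results at the same path.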
def plot_test_acc_from_csv(self, csv_path: str, plot_path: str) -> None:
245    def plot_test_acc_from_csv(self, csv_path: str, plot_path: str) -> None:
246        """Plot the test accuracy bar chart of all tasks in multi-task learning from the saved CSV file and save the plot to the designated directory.
247
248        **Args:**
249        - **csv_path** (`str`): the path to the CSV file where `save_test_acc_to_csv()` saved the test accuracy metrics.
250        - **plot_path** (`str`): the path to save the plot to, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc.png'.
251        """
252        data = pd.read_csv(csv_path)
253
254        # extract all accuracy columns including average
255        all_columns = data.columns.tolist()
256        task_ids = list(range(len(all_columns)))  # assign index-based positions
257        labels = [
258            (
259                col.replace("test_on_task_", "Task ")
260                if "test_on_task_" in col
261                else "Average"
262            )
263            for col in all_columns
264        ]
265        accuracies = data.iloc[0][all_columns].values
266
267        # plot the accuracy bar chart over tasks
268        fig, ax = plt.subplots(figsize=(16, 9))
269        ax.bar(
270            task_ids,
271            accuracies,
272            color="skyblue",
273            edgecolor="black",
274        )
275        ax.set_xlabel("Task", fontsize=16)
276        ax.set_ylabel("Accuracy", fontsize=16)
277        ax.grid(True)
278        ax.set_xticks(task_ids)
279        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=14)
280        ax.set_yticks([i * 0.05 for i in range(21)])
281        ax.set_yticklabels(
282            [f"{tick:.2f}" for tick in [i * 0.05 for i in range(21)]], fontsize=14
283        )
284        fig.tight_layout()
285        fig.savefig(plot_path)
286        plt.close(fig)

Plot the test accuracy bar chart of all tasks in multi-task learning from the saved CSV file and save the plot to the designated directory.

Args:

  • csv_path (str): the path to the CSV file where save_test_acc_to_csv() saved the test accuracy metrics.
  • plot_path (str): the path to save the plot to, preferably inside the output directory of the experiment, e.g. './outputs/expr_name/1970-01-01_00-00-00/acc.png'.
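
Because this method only reads the CSV from disk, it can also be called on its own to regenerate the figure from previously saved results; the paths below are hypothetical:

    metric = MTLAccuracy(save_dir="./outputs/expr_name/1970-01-01_00-00-00/results")
    metric.plot_test_acc_from_csv(
        csv_path="./outputs/expr_name/1970-01-01_00-00-00/results/acc.csv",
        plot_path="./outputs/expr_name/1970-01-01_00-00-00/results/acc.png",
    )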