clarena.metrics.stl_acc

The submodule in metrics for STLAccuracy.

  1r"""
  2The submodule in `metrics` for `STLAccuracy`.
  3"""
  4
  5__all__ = ["STLAccuracy"]
  6
  7import csv
  8import logging
  9import os
 10from typing import Any
 11
 12from lightning import Trainer
 13
 14from clarena.metrics import MetricCallback
 15from clarena.stl_algorithms import STLAlgorithm
 16from clarena.utils.metrics import MeanMetricBatch
 17
 18# always get logger for built-in logging in each module
 19pylogger = logging.getLogger(__name__)
 20
 21
 22class STLAccuracy(MetricCallback):
 23    r"""Provides all actions that are related to STL accuracy metric, which include:
 24
 25    - Defining, initializing and recording accuracy metric.
 26    - Logging training and validation accuracy metric to Lightning loggers in real time.
 27
 28    Saving test accuracy metric to files.
 29
 30    - The callback is able to produce the following outputs:
 31    - CSV files for test accuracy.
 32    """
 33
 34    def __init__(
 35        self,
 36        save_dir: str,
 37        test_acc_csv_name: str = "acc.csv",
 38    ) -> None:
 39        r"""
 40        **Args:**
 41        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 42        - **test_acc_csv_name** (`str`): file name to save test accuracy of all tasks and average accuracy as CSV file.
 43        """
 44        super().__init__(save_dir=save_dir)
 45
 46        # paths
 47        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 48        r"""The path to save test accuracy of all tasks and average accuracy CSV file."""
 49
 50        # training accumulated metrics
 51        self.acc_training_epoch: MeanMetricBatch
 52        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. """
 53
 54        # validation accumulated metrics
 55        self.acc_val: MeanMetricBatch
 56        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. """
 57
 58        # test accumulated metrics
 59        self.acc_test: MeanMetricBatch
 60        r"""Test classification accuracy. Accumulated and calculated from the test batches."""
 61
 62    def on_fit_start(self, trainer: Trainer, pl_module: STLAlgorithm) -> None:
 63        r"""Initialize training and validation metrics."""
 64
 65        # initialize training metrics
 66        self.acc_training_epoch = MeanMetricBatch()
 67
 68        # initialize validation metrics
 69        self.acc_val = MeanMetricBatch()
 70
 71    def on_train_batch_end(
 72        self,
 73        trainer: Trainer,
 74        pl_module: STLAlgorithm,
 75        outputs: dict[str, Any],
 76        batch: Any,
 77        batch_idx: int,
 78    ) -> None:
 79        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
 80
 81        **Args:**
 82        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `STLAlgorithm`.
 83        - **batch** (`Any`): the training data batch.
 84        """
 85        # get the batch size
 86        batch_size = len(batch)
 87
 88        # get training metrics values of current training batch from the outputs of the `training_step()`
 89        acc_batch = outputs["acc"]
 90
 91        # update accumulated training metrics to calculate training metrics of the epoch
 92        self.acc_training_epoch.update(acc_batch, batch_size)
 93
 94        # log training metrics of current training batch to Lightning loggers
 95        pl_module.log("train/acc_batch", acc_batch, prog_bar=True)
 96
 97        # log accumulated training metrics till this training batch to Lightning loggers
 98        pl_module.log(
 99            "task/train/acc",
100            self.acc_training_epoch.compute(),
101            prog_bar=True,
102        )
103
104    def on_train_epoch_end(
105        self,
106        trainer: Trainer,
107        pl_module: STLAlgorithm,
108    ) -> None:
109        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
110
111        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
112        pl_module.log(
113            "learning_curve/train/acc",
114            self.acc_training_epoch.compute(),
115            on_epoch=True,
116            prog_bar=True,
117        )
118
119        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
120        self.acc_training_epoch.reset()
121
122    def on_validation_batch_end(
123        self,
124        trainer: Trainer,
125        pl_module: STLAlgorithm,
126        outputs: dict[str, Any],
127        batch: Any,
128        batch_idx: int,
129    ) -> None:
130        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
131
132        **Args:**
133        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `STLAlgorithm`.
134        - **batch** (`Any`): the validation data batch.
135        """
136
137        # get the batch size
138        batch_size = len(batch)
139
140        # get the metrics values of the batch from the outputs
141        acc_batch = outputs["acc"]
142
143        # update the accumulated metrics in order to calculate the validation metrics
144        self.acc_val.update(acc_batch, batch_size)
145
146    def on_validation_epoch_end(
147        self,
148        trainer: Trainer,
149        pl_module: STLAlgorithm,
150    ) -> None:
151        r"""Log validation metrics to plot learning curves."""
152
153        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
154        pl_module.log(
155            "learning_curve/val/acc",
156            self.acc_val.compute(),
157            on_epoch=True,
158            prog_bar=True,
159        )
160
161    def on_test_start(
162        self,
163        trainer: Trainer,
164        pl_module: STLAlgorithm,
165    ) -> None:
166        r"""Initialize the testing metrics."""
167
168        # initialize test metrics for current and previous tasks
169        self.acc_test = MeanMetricBatch()
170
171    def on_test_batch_end(
172        self,
173        trainer: Trainer,
174        pl_module: STLAlgorithm,
175        outputs: dict[str, Any],
176        batch: Any,
177        batch_idx: int,
178    ) -> None:
179        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
180
181        **Args:**
182        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `STLAlgorithm`.
183        - **batch** (`Any`): the test data batch.
184        """
185
186        # get the batch size
187        batch_size = len(batch)
188
189        # get the metrics values of the batch from the outputs
190        acc_batch = outputs["acc"]
191
192        # update the accumulated metrics in order to calculate the metrics of the epoch
193        self.acc_test.update(acc_batch, batch_size)
194
195    def on_test_epoch_end(
196        self,
197        trainer: Trainer,
198        pl_module: STLAlgorithm,
199    ) -> None:
200        r"""Save and plot test metrics at the end of test."""
201
202        # save (update) the test metrics to CSV files
203        self.save_test_acc_to_csv(
204            csv_path=self.test_acc_csv_path,
205        )
206
207    def save_test_acc_to_csv(
208        self,
209        csv_path: str,
210    ) -> None:
211        r"""Save the test accuracy metrics of all tasks in single-task learning to an CSV file.
212
213        **Args:**
214        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
215        """
216        fieldnames = ["accuracy"]
217        new_line = {}
218        new_line["accuracy"] = self.acc_test.compute().item()
219
220        # write
221        with open(csv_path, "w", encoding="utf-8") as file:
222            writer = csv.DictWriter(file, fieldnames=fieldnames)
223            writer.writeheader()
224            writer.writerow(new_line)
class STLAccuracy(clarena.metrics.base.MetricCallback):
 23class STLAccuracy(MetricCallback):
 24    r"""Provides all actions that are related to STL accuracy metric, which include:
 25
 26    - Defining, initializing and recording accuracy metric.
 27    - Logging training and validation accuracy metric to Lightning loggers in real time.
 28
 29    Saving test accuracy metric to files.
 30
 31    - The callback is able to produce the following outputs:
 32    - CSV files for test accuracy.
 33    """
 34
 35    def __init__(
 36        self,
 37        save_dir: str,
 38        test_acc_csv_name: str = "acc.csv",
 39    ) -> None:
 40        r"""
 41        **Args:**
 42        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
 43        - **test_acc_csv_name** (`str`): file name to save test accuracy of all tasks and average accuracy as CSV file.
 44        """
 45        super().__init__(save_dir=save_dir)
 46
 47        # paths
 48        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
 49        r"""The path to save test accuracy of all tasks and average accuracy CSV file."""
 50
 51        # training accumulated metrics
 52        self.acc_training_epoch: MeanMetricBatch
 53        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. """
 54
 55        # validation accumulated metrics
 56        self.acc_val: MeanMetricBatch
 57        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. """
 58
 59        # test accumulated metrics
 60        self.acc_test: MeanMetricBatch
 61        r"""Test classification accuracy. Accumulated and calculated from the test batches."""
 62
 63    def on_fit_start(self, trainer: Trainer, pl_module: STLAlgorithm) -> None:
 64        r"""Initialize training and validation metrics."""
 65
 66        # initialize training metrics
 67        self.acc_training_epoch = MeanMetricBatch()
 68
 69        # initialize validation metrics
 70        self.acc_val = MeanMetricBatch()
 71
 72    def on_train_batch_end(
 73        self,
 74        trainer: Trainer,
 75        pl_module: STLAlgorithm,
 76        outputs: dict[str, Any],
 77        batch: Any,
 78        batch_idx: int,
 79    ) -> None:
 80        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
 81
 82        **Args:**
 83        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `STLAlgorithm`.
 84        - **batch** (`Any`): the training data batch.
 85        """
 86        # get the batch size
 87        batch_size = len(batch)
 88
 89        # get training metrics values of current training batch from the outputs of the `training_step()`
 90        acc_batch = outputs["acc"]
 91
 92        # update accumulated training metrics to calculate training metrics of the epoch
 93        self.acc_training_epoch.update(acc_batch, batch_size)
 94
 95        # log training metrics of current training batch to Lightning loggers
 96        pl_module.log("train/acc_batch", acc_batch, prog_bar=True)
 97
 98        # log accumulated training metrics till this training batch to Lightning loggers
 99        pl_module.log(
100            "task/train/acc",
101            self.acc_training_epoch.compute(),
102            prog_bar=True,
103        )
104
105    def on_train_epoch_end(
106        self,
107        trainer: Trainer,
108        pl_module: STLAlgorithm,
109    ) -> None:
110        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
111
112        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
113        pl_module.log(
114            "learning_curve/train/acc",
115            self.acc_training_epoch.compute(),
116            on_epoch=True,
117            prog_bar=True,
118        )
119
120        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
121        self.acc_training_epoch.reset()
122
123    def on_validation_batch_end(
124        self,
125        trainer: Trainer,
126        pl_module: STLAlgorithm,
127        outputs: dict[str, Any],
128        batch: Any,
129        batch_idx: int,
130    ) -> None:
131        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
132
133        **Args:**
134        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `STLAlgorithm`.
135        - **batch** (`Any`): the validation data batch.
136        """
137
138        # get the batch size
139        batch_size = len(batch)
140
141        # get the metrics values of the batch from the outputs
142        acc_batch = outputs["acc"]
143
144        # update the accumulated metrics in order to calculate the validation metrics
145        self.acc_val.update(acc_batch, batch_size)
146
147    def on_validation_epoch_end(
148        self,
149        trainer: Trainer,
150        pl_module: STLAlgorithm,
151    ) -> None:
152        r"""Log validation metrics to plot learning curves."""
153
154        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
155        pl_module.log(
156            "learning_curve/val/acc",
157            self.acc_val.compute(),
158            on_epoch=True,
159            prog_bar=True,
160        )
161
162    def on_test_start(
163        self,
164        trainer: Trainer,
165        pl_module: STLAlgorithm,
166    ) -> None:
167        r"""Initialize the testing metrics."""
168
169        # initialize test metrics for current and previous tasks
170        self.acc_test = MeanMetricBatch()
171
172    def on_test_batch_end(
173        self,
174        trainer: Trainer,
175        pl_module: STLAlgorithm,
176        outputs: dict[str, Any],
177        batch: Any,
178        batch_idx: int,
179    ) -> None:
180        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
181
182        **Args:**
183        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `STLAlgorithm`.
184        - **batch** (`Any`): the test data batch.
185        """
186
187        # get the batch size
188        batch_size = len(batch)
189
190        # get the metrics values of the batch from the outputs
191        acc_batch = outputs["acc"]
192
193        # update the accumulated metrics in order to calculate the metrics of the epoch
194        self.acc_test.update(acc_batch, batch_size)
195
196    def on_test_epoch_end(
197        self,
198        trainer: Trainer,
199        pl_module: STLAlgorithm,
200    ) -> None:
201        r"""Save and plot test metrics at the end of test."""
202
203        # save (update) the test metrics to CSV files
204        self.save_test_acc_to_csv(
205            csv_path=self.test_acc_csv_path,
206        )
207
208    def save_test_acc_to_csv(
209        self,
210        csv_path: str,
211    ) -> None:
212        r"""Save the test accuracy metrics of all tasks in single-task learning to an CSV file.
213
214        **Args:**
215        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
216        """
217        fieldnames = ["accuracy"]
218        new_line = {}
219        new_line["accuracy"] = self.acc_test.compute().item()
220
221        # write
222        with open(csv_path, "w", encoding="utf-8") as file:
223            writer = csv.DictWriter(file, fieldnames=fieldnames)
224            writer.writeheader()
225            writer.writerow(new_line)

Provides all actions that are related to STL accuracy metric, which include:

  • Defining, initializing and recording accuracy metric.
  • Logging training and validation accuracy metric to Lightning loggers in real time.
  • Saving test accuracy metric to files.

The callback is able to produce the following outputs:

  • CSV files for test accuracy.
STLAccuracy(save_dir: str, test_acc_csv_name: str = 'acc.csv')
35    def __init__(
36        self,
37        save_dir: str,
38        test_acc_csv_name: str = "acc.csv",
39    ) -> None:
40        r"""
41        **Args:**
42        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. Better inside the output folder.
43        - **test_acc_csv_name** (`str`): file name to save test accuracy of all tasks and average accuracy as CSV file.
44        """
45        super().__init__(save_dir=save_dir)
46
47        # paths
48        self.test_acc_csv_path: str = os.path.join(save_dir, test_acc_csv_name)
49        r"""The path to save test accuracy of all tasks and average accuracy CSV file."""
50
51        # training accumulated metrics
52        self.acc_training_epoch: MeanMetricBatch
53        r"""Classification accuracy of training epoch. Accumulated and calculated from the training batches. """
54
55        # validation accumulated metrics
56        self.acc_val: MeanMetricBatch
57        r"""Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches. """
58
59        # test accumulated metrics
60        self.acc_test: MeanMetricBatch
61        r"""Test classification accuracy. Accumulated and calculated from the test batches."""

Args:

  • save_dir (str): The directory where data and figures of metrics will be saved. Better inside the output folder.
  • test_acc_csv_name (str): file name to save test accuracy of all tasks and average accuracy as CSV file.
test_acc_csv_path: str

The path to save test accuracy of all tasks and average accuracy CSV file.

Classification accuracy of training epoch. Accumulated and calculated from the training batches.

Validation classification accuracy of the model after training epoch. Accumulated and calculated from the validation batches.

Test classification accuracy. Accumulated and calculated from the test batches.

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:
63    def on_fit_start(self, trainer: Trainer, pl_module: STLAlgorithm) -> None:
64        r"""Initialize training and validation metrics."""
65
66        # initialize training metrics
67        self.acc_training_epoch = MeanMetricBatch()
68
69        # initialize validation metrics
70        self.acc_val = MeanMetricBatch()

Initialize training and validation metrics.

def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
 72    def on_train_batch_end(
 73        self,
 74        trainer: Trainer,
 75        pl_module: STLAlgorithm,
 76        outputs: dict[str, Any],
 77        batch: Any,
 78        batch_idx: int,
 79    ) -> None:
 80        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
 81
 82        **Args:**
 83        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `STLAlgorithm`.
 84        - **batch** (`Any`): the training data batch.
 85        """
 86        # get the batch size
 87        batch_size = len(batch)
 88
 89        # get training metrics values of current training batch from the outputs of the `training_step()`
 90        acc_batch = outputs["acc"]
 91
 92        # update accumulated training metrics to calculate training metrics of the epoch
 93        self.acc_training_epoch.update(acc_batch, batch_size)
 94
 95        # log training metrics of current training batch to Lightning loggers
 96        pl_module.log("train/acc_batch", acc_batch, prog_bar=True)
 97
 98        # log accumulated training metrics till this training batch to Lightning loggers
 99        pl_module.log(
100            "task/train/acc",
101            self.acc_training_epoch.compute(),
102            prog_bar=True,
103        )

Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, the returns of the training_step() method in the STLAlgorithm.
  • batch (Any): the training data batch.
def on_train_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:
105    def on_train_epoch_end(
106        self,
107        trainer: Trainer,
108        pl_module: STLAlgorithm,
109    ) -> None:
110        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
111
112        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
113        pl_module.log(
114            "learning_curve/train/acc",
115            self.acc_training_epoch.compute(),
116            on_epoch=True,
117            prog_bar=True,
118        )
119
120        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
121        self.acc_training_epoch.reset()

Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.

def on_validation_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
123    def on_validation_batch_end(
124        self,
125        trainer: Trainer,
126        pl_module: STLAlgorithm,
127        outputs: dict[str, Any],
128        batch: Any,
129        batch_idx: int,
130    ) -> None:
131        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
132
133        **Args:**
134        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `STLAlgorithm`.
135        - **batch** (`Any`): the validation data batch.
136        """
137
138        # get the batch size
139        batch_size = len(batch)
140
141        # get the metrics values of the batch from the outputs
142        acc_batch = outputs["acc"]
143
144        # update the accumulated metrics in order to calculate the validation metrics
145        self.acc_val.update(acc_batch, batch_size)

Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.

Args:

  • outputs (dict[str, Any]): the outputs of the validation step, which is the returns of the validation_step() method in the STLAlgorithm.
  • batch (Any): the validation data batch.
def on_validation_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:
147    def on_validation_epoch_end(
148        self,
149        trainer: Trainer,
150        pl_module: STLAlgorithm,
151    ) -> None:
152        r"""Log validation metrics to plot learning curves."""
153
154        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
155        pl_module.log(
156            "learning_curve/val/acc",
157            self.acc_val.compute(),
158            on_epoch=True,
159            prog_bar=True,
160        )

Log validation metrics to plot learning curves.

def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:
162    def on_test_start(
163        self,
164        trainer: Trainer,
165        pl_module: STLAlgorithm,
166    ) -> None:
167        r"""Initialize the testing metrics."""
168
169        # initialize test metrics for current and previous tasks
170        self.acc_test = MeanMetricBatch()

Initialize the testing metrics.

def on_test_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:
172    def on_test_batch_end(
173        self,
174        trainer: Trainer,
175        pl_module: STLAlgorithm,
176        outputs: dict[str, Any],
177        batch: Any,
178        batch_idx: int,
179    ) -> None:
180        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
181
182        **Args:**
183        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `STLAlgorithm`.
184        - **batch** (`Any`): the test data batch.
185        """
186
187        # get the batch size
188        batch_size = len(batch)
189
190        # get the metrics values of the batch from the outputs
191        acc_batch = outputs["acc"]
192
193        # update the accumulated metrics in order to calculate the metrics of the epoch
194        self.acc_test.update(acc_batch, batch_size)

Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, which is the returns of the test_step() method in the STLAlgorithm.
  • batch (Any): the test data batch.
def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:
196    def on_test_epoch_end(
197        self,
198        trainer: Trainer,
199        pl_module: STLAlgorithm,
200    ) -> None:
201        r"""Save and plot test metrics at the end of test."""
202
203        # save (update) the test metrics to CSV files
204        self.save_test_acc_to_csv(
205            csv_path=self.test_acc_csv_path,
206        )

Save test metrics at the end of test.

def save_test_acc_to_csv(self, csv_path: str) -> None:
208    def save_test_acc_to_csv(
209        self,
210        csv_path: str,
211    ) -> None:
212        r"""Save the test accuracy metrics of all tasks in single-task learning to an CSV file.
213
214        **Args:**
215        - **csv_path** (`str`): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.
216        """
217        fieldnames = ["accuracy"]
218        new_line = {}
219        new_line["accuracy"] = self.acc_test.compute().item()
220
221        # write
222        with open(csv_path, "w", encoding="utf-8") as file:
223            writer = csv.DictWriter(file, fieldnames=fieldnames)
224            writer.writeheader()
225            writer.writerow(new_line)

Save the test accuracy metric in single-task learning to a CSV file.

Args:

  • csv_path (str): save the test metric to path. E.g. './outputs/expr_name/1970-01-01_00-00-00/results/acc.csv'.