clarena.metrics.stl_loss

The submodule in metrics for STLLoss.

  1r"""
  2The submodule in `metrics` for `STLLoss`.
  3"""
  4
  5__all__ = ["STLLoss"]
  6
  7import csv
  8import logging
  9import os
 10from typing import Any
 11
 12from lightning import Trainer
 13
 14from clarena.metrics import MetricCallback
 15from clarena.stl_algorithms import STLAlgorithm
 16from clarena.utils.metrics import MeanMetricBatch
 17
 18# always get logger for built-in logging in each module
 19pylogger = logging.getLogger(__name__)
 20
 21
 22class STLLoss(MetricCallback):
 23    r"""Provides all actions that are related to STL loss metrics, which include:
 24
 25    - Defining, initializing and recording loss metrics.
 26    - Logging training and validation loss metrics to Lightning loggers in real time.
 27
 28    Saving test loss metrics to files.
 29
 30    - The callback is able to produce the following outputs:
 31    - CSV files for test classification loss.
 32    """
 33
 34    def __init__(
 35        self,
 36        save_dir: str,
 37        test_loss_cls_csv_name: str = "loss_cls.csv",
 38    ) -> None:
 39        r"""
 40        **Args:**
  41        - **save_dir** (`str`): The directory where data and figures of metrics will be saved. It is recommended to be inside the output folder.
  42        - **test_loss_cls_csv_name** (`str`): file name for saving the test classification loss as a CSV file.
 43        """
 44        super().__init__(save_dir=save_dir)
 45
 46        # paths
 47        self.test_loss_cls_csv_path: str = os.path.join(
 48            save_dir, test_loss_cls_csv_name
 49        )
 50        r"""The path to save test classification loss of all tasks and average classification loss CSV file."""
 51
 52        # training accumulated metrics
 53        self.loss_cls_training_epoch: MeanMetricBatch
 54        r"""Classification loss of training epoch. Accumulated and calculated from the training batches. """
 55
 56        # validation accumulated metrics
 57        self.loss_cls_val: MeanMetricBatch
 58        r"""Validation classification of the model loss after training epoch. Accumulated and calculated from the validation batches. """
 59
 60        # test accumulated metrics
 61        self.loss_cls_test: MeanMetricBatch
 62        r"""Test classification loss. Accumulated and calculated from the test batches."""
 63
 64    def on_fit_start(self, trainer: Trainer, pl_module: STLAlgorithm) -> None:
 65        r"""Initialize training and validation metrics."""
 66
 67        # initialize training metrics
 68        self.loss_cls_training_epoch = MeanMetricBatch()
 69
 70        # initialize validation metrics
 71        self.loss_cls_val = MeanMetricBatch()
 72
 73    def on_train_batch_end(
 74        self,
 75        trainer: Trainer,
 76        pl_module: STLAlgorithm,
 77        outputs: dict[str, Any],
 78        batch: Any,
 79        batch_idx: int,
 80    ) -> None:
 81        r"""Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.
 82
 83        **Args:**
 84        - **outputs** (`dict[str, Any]`): the outputs of the training step, the returns of the `training_step()` method in the `STLAlgorithm`.
 85        - **batch** (`Any`): the training data batch.
 86        """
 87        # get the batch size
 88        batch_size = len(batch)
 89
 90        # get training metrics values of current training batch from the outputs of the `training_step()`
 91        loss_cls_batch = outputs["loss_cls"]
 92
 93        # update accumulated training metrics to calculate training metrics of the epoch
 94        self.loss_cls_training_epoch.update(loss_cls_batch, batch_size)
 95
 96        # log training metrics of current training batch to Lightning loggers
 97        pl_module.log("train/loss_cls_batch", loss_cls_batch, prog_bar=True)
 98
 99        # log accumulated training metrics till this training batch to Lightning loggers
100        pl_module.log(
101            "task/train/loss_cls",
102            self.loss_cls_training_epoch.compute(),
103            prog_bar=True,
104        )
105
106    def on_train_epoch_end(
107        self,
108        trainer: Trainer,
109        pl_module: STLAlgorithm,
110    ) -> None:
111        r"""Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch."""
112
113        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
114        pl_module.log(
115            "learning_curve/train/loss_cls",
116            self.loss_cls_training_epoch.compute(),
117            on_epoch=True,
118            prog_bar=True,
119        )
120
121        # reset the metrics of training epoch as there are more epochs to go and not only one epoch like in the validation and test
122        self.loss_cls_training_epoch.reset()
123
124    def on_validation_batch_end(
125        self,
126        trainer: Trainer,
127        pl_module: STLAlgorithm,
128        outputs: dict[str, Any],
129        batch: Any,
130        batch_idx: int,
131    ) -> None:
132        r"""Accumulating metrics from validation batch. We don't need to log and monitor the metrics of validation batches.
133
134        **Args:**
135        - **outputs** (`dict[str, Any]`): the outputs of the validation step, which is the returns of the `validation_step()` method in the `STLAlgorithm`.
136        - **batch** (`Any`): the validation data batch.
137        """
138
139        # get the batch size
140        batch_size = len(batch)
141
142        # get the metrics values of the batch from the outputs
143        loss_cls_batch = outputs["loss_cls"]
144
145        # update the accumulated metrics in order to calculate the validation metrics
146        self.loss_cls_val.update(loss_cls_batch, batch_size)
147
148    def on_validation_epoch_end(
149        self,
150        trainer: Trainer,
151        pl_module: STLAlgorithm,
152    ) -> None:
153        r"""Log validation metrics to plot learning curves."""
154
155        # log the accumulated and computed metrics of the epoch to Lightning loggers, specially for plotting learning curves
156        pl_module.log(
157            "learning_curve/val/loss_cls",
158            self.loss_cls_val.compute(),
159            on_epoch=True,
160            prog_bar=True,
161        )
162
163    def on_test_start(
164        self,
165        trainer: Trainer,
166        pl_module: STLAlgorithm,
167    ) -> None:
168        r"""Initialize the metrics for testing each seen task in the beginning of a task's testing."""
169
170        # initialize test metrics
171        self.loss_cls_test = MeanMetricBatch()
172
173    def on_test_batch_end(
174        self,
175        trainer: Trainer,
176        pl_module: STLAlgorithm,
177        outputs: dict[str, Any],
178        batch: Any,
179        batch_idx: int,
180    ) -> None:
181        r"""Accumulating metrics from test batch. We don't need to log and monitor the metrics of test batches.
182
183        **Args:**
184        - **outputs** (`dict[str, Any]`): the outputs of the test step, which is the returns of the `test_step()` method in the `STLAlgorithm`.
185        - **batch** (`Any`): the test data batch.
186        """
187
188        # get the batch size
189        batch_size = len(batch)
190
191        # get the metrics values of the batch from the outputs
192        loss_cls_batch = outputs["loss_cls"]
193
194        # update the accumulated metrics in order to calculate the metrics of the epoch
195        self.loss_cls_test.update(loss_cls_batch, batch_size)
196
197    def on_test_epoch_end(
198        self,
199        trainer: Trainer,
200        pl_module: STLAlgorithm,
201    ) -> None:
202        r"""Save and plot test metrics at the end of test."""
203
204        # save (update) the test metrics to CSV files
205        self.save_test_loss_cls_to_csv(
206            csv_path=self.test_loss_cls_csv_path,
207        )
208
209    def save_test_loss_cls_to_csv(
210        self,
211        csv_path: str,
212    ) -> None:
213        r"""Save the test classification loss metrics in single-task learning to an CSV file.
214
215        **Args:**
 216        - **csv_path** (`str`): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
217        """
218        fieldnames = ["loss_cls"]
219        new_line = {"loss_cls": self.loss_cls_test.compute().item()}
220
 221        # write (newline="" prevents the csv module from inserting extra blank lines on some platforms)
 222        with open(csv_path, "w", encoding="utf-8", newline="") as file:
223            writer = csv.DictWriter(file, fieldnames=fieldnames)
224            writer.writeheader()
225            writer.writerow(new_line)
class STLLoss(clarena.metrics.base.MetricCallback):

Provides all actions that are related to STL loss metrics, which include:

  • Defining, initializing and recording loss metrics.
  • Logging training and validation loss metrics to Lightning loggers in real time.
  • Saving test loss metrics to files.

The callback is able to produce the following outputs:

  • CSV files for test classification loss.
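
As a usage sketch (not taken verbatim from CLArena), the callback is attached to a Lightning Trainer like any other callback. The model and datamodule objects below are hypothetical placeholders for an STLAlgorithm and a LightningDataModule constructed elsewhere:

    from lightning import Trainer

    from clarena.metrics.stl_loss import STLLoss

    # record and save STL loss metrics under this directory (assumed to already exist)
    loss_metric = STLLoss(save_dir="./outputs/example_stl/results")

    trainer = Trainer(max_epochs=10, callbacks=[loss_metric])
    trainer.fit(model, datamodule=datamodule)   # logs training/validation loss in real time
    trainer.test(model, datamodule=datamodule)  # writes loss_cls.csv to save_dir
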
STLLoss(save_dir: str, test_loss_cls_csv_name: str = 'loss_cls.csv')

Args:

  • save_dir (str): The directory where data and figures of metrics will be saved. It is recommended to be inside the output folder.
  • test_loss_cls_csv_name (str): file name for saving the test classification loss as a CSV file.
test_loss_cls_csv_path: str

The path to the CSV file where the test classification loss is saved.

loss_cls_training_epoch: clarena.utils.metrics.MeanMetricBatch

Classification loss of training epoch. Accumulated and calculated from the training batches.

loss_cls_val: clarena.utils.metrics.MeanMetricBatch

Validation classification loss of the model after each training epoch. Accumulated and calculated from the validation batches.

loss_cls_test: clarena.utils.metrics.MeanMetricBatch

Test classification loss. Accumulated and calculated from the test batches.
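
The MeanMetricBatch accumulators above come from clarena.utils.metrics and their actual implementation is not shown on this page. The stand-in below only illustrates the accumulation semantics this callback relies on (a batch-size-weighted running mean exposing update(), compute() and reset()); it is an assumption for illustration, not the real class:

    import torch


    class WeightedRunningMean:
        """Illustrative stand-in for MeanMetricBatch: a batch-size-weighted mean."""

        def __init__(self) -> None:
            self.total = torch.tensor(0.0)  # running sum of (batch mean * batch size)
            self.count = torch.tensor(0.0)  # total number of samples seen so far

        def update(self, value: torch.Tensor, batch_size: int) -> None:
            # weight the batch's mean loss by the number of samples in the batch
            self.total = self.total + value.detach().cpu() * batch_size
            self.count = self.count + batch_size

        def compute(self) -> torch.Tensor:
            # mean over all samples accumulated since the last reset
            return self.total / self.count

        def reset(self) -> None:
            self.total = torch.tensor(0.0)
            self.count = torch.tensor(0.0)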

def on_fit_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:

Initialize training and validation metrics.

def on_train_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:

Record training metrics from training batch, log metrics of training batch and accumulated metrics of the epoch to Lightning loggers.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, the returns of the training_step() method in the STLAlgorithm.
  • batch (Any): the training data batch.
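
The callback reads the loss_cls entry from whatever training_step() returns, so the algorithm's step is expected to return a dictionary containing that key. A minimal sketch of such a module, for illustration only (the real STLAlgorithm subclasses in CLArena differ):

    import torch
    from torch import nn
    from lightning import LightningModule


    class ToySTLAlgorithm(LightningModule):
        """Illustrative only: shows the output contract that STLLoss expects."""

        def __init__(self) -> None:
            super().__init__()
            self.model = nn.Linear(4, 3)
            self.criterion = nn.CrossEntropyLoss()

        def training_step(self, batch, batch_idx):
            x, y = batch  # assumes (inputs, targets) batches
            loss_cls = self.criterion(self.model(x), y)
            # "loss" is what Lightning backpropagates; "loss_cls" is what STLLoss reads
            return {"loss": loss_cls, "loss_cls": loss_cls}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)
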
def on_train_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:

Log metrics of training epoch to plot learning curves and reset the metrics accumulation at the end of training epoch.

def on_validation_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:

Accumulate metrics from the validation batch. We don't need to log or monitor the metrics of individual validation batches.

Args:

  • outputs (dict[str, Any]): the outputs of the validation step, which is the returns of the validation_step() method in the STLAlgorithm.
  • batch (Any): the validation data batch.
def on_validation_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:

Log validation metrics to plot learning curves.
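
Because the per-epoch validation loss is logged under the key learning_curve/val/loss_cls, standard Lightning callbacks can monitor it. A sketch, assuming STLLoss is attached to the same Trainer:

    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

    early_stop = EarlyStopping(monitor="learning_curve/val/loss_cls", mode="min", patience=5)
    checkpoint = ModelCheckpoint(monitor="learning_curve/val/loss_cls", mode="min", save_top_k=1)
    # pass these alongside the STLLoss callback:
    # trainer = Trainer(callbacks=[loss_metric, early_stop, checkpoint])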

def on_test_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:

Initialize the test metrics at the beginning of testing.

def on_test_batch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:

Accumulate metrics from the test batch. We don't need to log or monitor the metrics of individual test batches.

Args:

  • outputs (dict[str, Any]): the outputs of the test step, which is the returns of the test_step() method in the STLAlgorithm.
  • batch (Any): the test data batch.
def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: clarena.stl_algorithms.STLAlgorithm) -> None:

Save test metrics to files at the end of the test epoch.

def save_test_loss_cls_to_csv(self, csv_path: str) -> None:

Save the test classification loss metrics in single-task learning to a CSV file.

Args:

  • csv_path (str): the path to save the test metric to, e.g. './outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv'.
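
For reference, the saved file contains a single loss_cls column with one row holding the average test classification loss. It can be read back with the standard library; the path below is the illustrative example from the docstring above:

    import csv

    with open(
        "./outputs/expr_name/1970-01-01_00-00-00/results/loss_cls.csv", encoding="utf-8"
    ) as file:
        row = next(csv.DictReader(file))

    print(float(row["loss_cls"]))  # average test classification loss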