clarena.stl_algorithms.single_learning

The submodule in stl_algorithms for the single learning algorithm.

  1r"""
  2The submodule in `stl_algorithms` for single learning algorithm.
  3"""
  4
  5__all__ = ["SingleLearning"]
  6
  7import logging
  8from typing import Any
  9
 10from torch import Tensor
 11from torch.utils.data import DataLoader
 12
 13from clarena.backbones import Backbone
 14from clarena.heads import HeadSTL
 15from clarena.stl_algorithms import STLAlgorithm
 16
 17# always get logger for built-in logging in each module
 18pylogger = logging.getLogger(__name__)
 19
 20
 21class SingleLearning(STLAlgorithm):
 22    r"""Single learning algorithm.
 23
 24    The most naive way for single-task learning. It directly trains the task.
 25    """
 26
 27    def __init__(
 28        self,
 29        backbone: Backbone,
 30        head: HeadSTL,
 31        non_algorithmic_hparams: dict[str, Any] = {},
 32    ) -> None:
 33        r"""Initialize the SingleLearning algorithm with the network. It has no additional hyperparameters.
 34
 35        **Args:**
 36        - **backbone** (`Backbone`): backbone network.
 37        - **head** (`HeadSTL`): output head.
 38        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility.
 39        """
 40        super().__init__(
 41            backbone=backbone,
 42            head=head,
 43            non_algorithmic_hparams=non_algorithmic_hparams,
 44        )
 45
 46    def training_step(self, batch: Any) -> dict[str, Tensor]:
 47        r"""Training step.
 48
 49        **Args:**
 50        - **batch** (`Any`): a batch of training data.
 51
 52        **Returns:**
 53        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and accuracy from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs.
 54        """
 55        x, y = batch
 56
 57        # classification loss
 58        logits, activations = self.forward(x, stage="train")
 59        loss_cls = self.criterion(logits, y)
 60
 61        # total loss
 62        loss = loss_cls
 63
 64        # predicted labels
 65        preds = logits.argmax(dim=1)
 66
 67        # accuracy of the batch
 68        acc = (preds == y).float().mean()
 69
 70        return {
 71            "preds": preds,
 72            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
 73            "loss_cls": loss_cls,
 74            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
 75            "activations": activations,
 76        }
 77
 78    def validation_step(self, batch: Any) -> dict[str, Tensor]:
 79        r"""Validation step.
 80
 81        **Args:**
 82        - **batch** (`Any`): a batch of validation data.
 83
 84        **Returns:**
 85        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and accuracy from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
 86        """
 87
 88        x, y = batch
 89        logits, _ = self.forward(x, stage="validation")
 90        loss_cls = self.criterion(logits, y)
 91        preds = logits.argmax(dim=1)
 92        acc = (preds == y).float().mean()
 93
 94        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
 95        return {
 96            "preds": preds,
 97            "loss_cls": loss_cls,
 98            "acc": acc,
 99        }
100
101    def test_step(self, batch: DataLoader) -> dict[str, Tensor]:
102        r"""Test step.
103
104        **Args:**
105        - **batch** (`Any`): a batch of test data.
106
107        **Returns:**
108        - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and accuracy from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics.
109        """
110
111        x, y = batch
112        logits, _ = self.forward(x, stage="test")
113        loss_cls = self.criterion(logits, y)
114        preds = logits.argmax(dim=1)
115        acc = (preds == y).float().mean()
116
117        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
118        return {
119            "preds": preds,
120            "loss_cls": loss_cls,
121            "acc": acc,
122        }
class SingleLearning(clarena.stl_algorithms.base.STLAlgorithm):

Single learning algorithm.

The most naive approach to single-task learning: it simply trains on the task directly.
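
As a quick orientation, here is a minimal construction sketch. The backbone class and all constructor arguments below are hypothetical placeholders, not clarena's exact signatures; consult clarena.backbones and clarena.heads for the real ones.

from clarena.backbones import MLP  # assumed backbone class; check clarena.backbones
from clarena.heads import HeadSTL
from clarena.stl_algorithms import SingleLearning

# All constructor arguments below are hypothetical placeholders.
backbone = MLP(input_dim=784, hidden_dims=[256, 128], output_dim=64)
head = HeadSTL(input_dim=64, num_classes=10)
model = SingleLearning(backbone=backbone, head=head)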

SingleLearning(backbone: clarena.backbones.Backbone, head: clarena.heads.HeadSTL, non_algorithmic_hparams: dict[str, typing.Any] = {})

Initialize the SingleLearning algorithm with the network. It has no additional hyperparameters.

Args:

  • backbone (Backbone): backbone network.
  • head (HeadSTL): output head.
  • non_algorithmic_hparams (dict[str, Any]): hyperparameters not related to the algorithm itself, passed to this LightningModule from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the save_hyperparameters() method, which helps with experiment configuration and reproducibility.
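
To illustrate the last argument, a hedged sketch (continuing the construction above) of passing optimizer settings through non_algorithmic_hparams. Per the docstring, the base class is expected to forward this dict to Lightning's save_hyperparameters(), making the values visible on self.hparams and in experiment logs; the config keys below are hypothetical.

# Hypothetical config keys; the base STLAlgorithm is assumed to forward
# this dict to save_hyperparameters(), exposing it via self.hparams.
model = SingleLearning(
    backbone=backbone,
    head=head,
    non_algorithmic_hparams={"optimizer": "Adam", "lr": 1e-3},
)
print(model.hparams)  # should include "optimizer" and "lr"
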
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:

Training step.

Args:

  • batch (Any): a batch of training data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this training step. Keys (str) are the metric names, and values (Tensor) are the metrics. It must include the key 'loss', holding the total loss, when automatic optimization is used, according to the PyTorch Lightning docs.
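
The computation inside the step is ordinary cross-entropy plus batch accuracy. A self-contained sketch with dummy tensors (plain torch, no clarena dependencies) showing the same loss/preds/acc pipeline:

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()  # plays the role of self.criterion

logits = torch.randn(8, 10)        # dummy batch: 8 samples, 10 classes
y = torch.randint(0, 10, (8,))     # dummy integer class labels

loss_cls = criterion(logits, y)    # classification loss
preds = logits.argmax(dim=1)       # predicted labels
acc = (preds == y).float().mean()  # fraction of correct predictions in the batch

print(loss_cls.item(), acc.item())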
def validation_step(self, batch: Any) -> dict[str, torch.Tensor]:

Validation step.

Args:

  • batch (Any): a batch of validation data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this validation step. Keys (str) are the metric names, and values (Tensor) are the metrics.
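
The returned dictionary is consumed by a logging callback at on_validation_batch_end(). Clarena ships its own logger callbacks; the minimal sketch below (assuming the Lightning 2.x package) only illustrates the shape of that hand-off and is not clarena's actual callback:

import lightning as L

class MetricsLogger(L.Callback):
    """Illustrative callback: logs the metrics returned by validation_step."""

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        # `outputs` is the dict returned by validation_step above
        pl_module.log("val/loss_cls", outputs["loss_cls"])
        pl_module.log("val/acc", outputs["acc"])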
def test_step(self, batch: Any) -> dict[str, torch.Tensor]:

Test step.

Args:

  • batch (Any): a batch of test data.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this test step. Keys (str) are the metric names, and values (Tensor) are the metrics.
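
Putting it together, a hedged end-to-end sketch with a plain Lightning Trainer; datamodule is a placeholder for whatever STL datamodule the experiment config provides, and clarena's own entry points may wrap this differently.

from lightning import Trainer

trainer = Trainer(max_epochs=10, accelerator="auto")
trainer.fit(model, datamodule=datamodule)   # runs training_step / validation_step
trainer.test(model, datamodule=datamodule)  # runs test_step over the test set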