clarena.stl_algorithms.single_learning
The submodule in stl_algorithms
for the single learning algorithm.
r"""
The submodule in `stl_algorithms` for the single learning algorithm.
"""

__all__ = ["SingleLearning"]

import logging
from typing import Any

from torch import Tensor
from torch.utils.data import DataLoader

from clarena.backbones import Backbone
from clarena.heads import HeadSTL
from clarena.stl_algorithms import STLAlgorithm

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class SingleLearning(STLAlgorithm):
    r"""Single learning algorithm.

    The most naive approach to single-task learning: it trains directly on the task,
    using the classification loss as the only objective.
    """

    def __init__(
        self,
        backbone: Backbone,
        head: HeadSTL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
    ) -> None:
        r"""Initialize the SingleLearning algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **head** (`HeadSTL`): output head.
        - **non_algorithmic_hparams** (`dict[str, Any] | None`): non-algorithmic hyperparameters that are not related to the algorithm itself, such as optimizer and learning rate scheduler configurations, passed to this `LightningModule` object from the config. They are saved for Lightning APIs by the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility. `None` (the default) is replaced by a new empty dict.
        """
        # A literal `{}` default is evaluated once and shared by every instance (and by
        # the base class if it keeps the reference), so build a fresh dict per call.
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}

        super().__init__(
            backbone=backbone,
            head=head,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step.

        **Args:**
        - **batch** (`Any`): a batch of training data, unpacked as `(x, y)`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. It also contains `preds` (arg-max predictions), `loss_cls` and `activations` (the second output of `self.forward()`, passed through unchanged).
        """
        x, y = batch

        # classification loss
        logits, activations = self.forward(x, stage="train")
        loss_cls = self.criterion(logits, y)

        # total loss (classification is the only term)
        loss = loss_cls

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
            "activations": activations,
        }

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step.

        **Args:**
        - **batch** (`Any`): a batch of validation data, unpacked as `(x, y)`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this validation step, plus the arg-max predictions under `preds`. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _ = self.forward(x, stage="validation")
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def test_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Test step.

        **Args:**
        - **batch** (`Any`): a batch of test data, unpacked as `(x, y)`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this test step, plus the arg-max predictions under `preds`. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _ = self.forward(x, stage="test")
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,
        }
class
SingleLearning(clarena.stl_algorithms.base.STLAlgorithm):
class SingleLearning(STLAlgorithm):
    r"""Single learning algorithm.

    The most naive approach to single-task learning: it trains directly on the task,
    using the classification loss as the only objective.
    """

    def __init__(
        self,
        backbone: Backbone,
        head: HeadSTL,
        non_algorithmic_hparams: dict[str, Any] | None = None,
    ) -> None:
        r"""Initialize the SingleLearning algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **head** (`HeadSTL`): output head.
        - **non_algorithmic_hparams** (`dict[str, Any] | None`): non-algorithmic hyperparameters that are not related to the algorithm itself, such as optimizer and learning rate scheduler configurations, passed to this `LightningModule` object from the config. They are saved for Lightning APIs by the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility. `None` (the default) is replaced by a new empty dict.
        """
        # A literal `{}` default is evaluated once and shared by every instance (and by
        # the base class if it keeps the reference), so build a fresh dict per call.
        if non_algorithmic_hparams is None:
            non_algorithmic_hparams = {}

        super().__init__(
            backbone=backbone,
            head=head,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step.

        **Args:**
        - **batch** (`Any`): a batch of training data, unpacked as `(x, y)`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. It also contains `preds` (arg-max predictions), `loss_cls` and `activations` (the second output of `self.forward()`, passed through unchanged).
        """
        x, y = batch

        # classification loss
        logits, activations = self.forward(x, stage="train")
        loss_cls = self.criterion(logits, y)

        # total loss (classification is the only term)
        loss = loss_cls

        # predicted labels
        preds = logits.argmax(dim=1)

        # accuracy of the batch
        acc = (preds == y).float().mean()

        return {
            "preds": preds,
            "loss": loss,  # return loss is essential for training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,  # Return other metrics for lightning loggers callback to handle at `on_train_batch_end()`
            "activations": activations,
        }

    def validation_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Validation step.

        **Args:**
        - **batch** (`Any`): a batch of validation data, unpacked as `(x, y)`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this validation step, plus the arg-max predictions under `preds`. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _ = self.forward(x, stage="validation")
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_validation_batch_end()`
        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def test_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Test step.

        **Args:**
        - **batch** (`Any`): a batch of test data, unpacked as `(x, y)`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this test step, plus the arg-max predictions under `preds`. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
        """
        x, y = batch
        logits, _ = self.forward(x, stage="test")
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # Return metrics for lightning loggers callback to handle at `on_test_batch_end()`
        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,
        }
Single learning algorithm.
The most naive approach to single-task learning: it trains directly on the task.
SingleLearning( backbone: clarena.backbones.Backbone, head: clarena.heads.HeadSTL, non_algorithmic_hparams: dict[str, typing.Any] = {})
def __init__(
    self,
    backbone: Backbone,
    head: HeadSTL,
    non_algorithmic_hparams: dict[str, Any] | None = None,
) -> None:
    r"""Initialize the SingleLearning algorithm with the network. It has no additional hyperparameters.

    **Args:**
    - **backbone** (`Backbone`): backbone network.
    - **head** (`HeadSTL`): output head.
    - **non_algorithmic_hparams** (`dict[str, Any] | None`): non-algorithmic hyperparameters that are not related to the algorithm itself, such as optimizer and learning rate scheduler configurations, passed to this `LightningModule` object from the config. They are saved for Lightning APIs by the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility. `None` (the default) is replaced by a new empty dict.
    """
    # A literal `{}` default is evaluated once and shared by every instance (and by
    # the base class if it keeps the reference), so build a fresh dict per call.
    # Rebinding the parameter itself keeps the local in sync with what is passed on.
    if non_algorithmic_hparams is None:
        non_algorithmic_hparams = {}

    super().__init__(
        backbone=backbone,
        head=head,
        non_algorithmic_hparams=non_algorithmic_hparams,
    )
Initialize the SingleLearning algorithm with the network. It has no additional hyperparameters.
Args:
- backbone (Backbone): backbone network.
- head (HeadSTL): output head.
- non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself, such as optimizer and learning rate scheduler configurations, passed to this LightningModule object from the config. They are saved for Lightning APIs by the save_hyperparameters() method. This is useful for experiment configuration and reproducibility.
def
training_step(self, batch: Any) -> dict[str, torch.Tensor]:
def training_step(self, batch: Any) -> dict[str, Tensor]:
    r"""Training step.

    **Args:**
    - **batch** (`Any`): a batch of training data, unpacked as an `(inputs, targets)` pair.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs. It also carries `preds`, `loss_cls` and the `activations` returned by `self.forward()`.
    """
    inputs, targets = batch

    # the network returns logits together with its activations
    logits, activations = self.forward(inputs, stage="train")

    # classification is the only objective, so the total loss is that very tensor
    loss_cls = self.criterion(logits, targets)
    total_loss = loss_cls

    # batch accuracy from the arg-max class predictions
    predictions = logits.argmax(dim=1)
    hits = predictions == targets
    acc = hits.float().mean()

    # `loss` is required for backpropagation; the remaining entries are consumed by the
    # lightning loggers callback at `on_train_batch_end()`
    return dict(
        preds=predictions,
        loss=total_loss,
        loss_cls=loss_cls,
        acc=acc,
        activations=activations,
    )
Training step.
Args:
- batch (Any): a batch of training data.
Returns:
- outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this training step. Keys (str) are the metric names, and values (Tensor) are the metrics. It must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
def
validation_step(self, batch: Any) -> dict[str, torch.Tensor]:
def validation_step(self, batch: Any) -> dict[str, Tensor]:
    r"""Validation step.

    **Args:**
    - **batch** (`Any`): a batch of validation data, unpacked as an `(inputs, targets)` pair.

    **Returns:**
    - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this validation step, plus the arg-max predictions under `preds`. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
    """
    inputs, targets = batch

    # only the logits matter during validation; the second forward output is discarded
    logits = self.forward(inputs, stage="validation")[0]

    predictions = logits.argmax(dim=1)
    loss_cls = self.criterion(logits, targets)
    acc = (predictions == targets).float().mean()

    # handed to the lightning loggers callback at `on_validation_batch_end()`
    return dict(preds=predictions, loss_cls=loss_cls, acc=acc)
Validation step.
Args:
- batch (Any): a batch of validation data.
Returns:
- outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this validation step. Keys (str) are the metric names, and values (Tensor) are the metrics.
def
test_step( self, batch: torch.utils.data.dataloader.DataLoader) -> dict[str, torch.Tensor]:
102 def test_step(self, batch: DataLoader) -> dict[str, Tensor]: 103 r"""Test step. 104 105 **Args:** 106 - **batch** (`Any`): a batch of test data. 107 108 **Returns:** 109 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and accuracy from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 110 """ 111 112 x, y = batch 113 logits, _ = self.forward(x, stage="test") 114 loss_cls = self.criterion(logits, y) 115 preds = logits.argmax(dim=1) 116 acc = (preds == y).float().mean() 117 118 # Return metrics for lightning loggers callback to handle at `on_test_batch_end()` 119 return { 120 "preds": preds, 121 "loss_cls": loss_cls, 122 "acc": acc, 123 }
Test step.
Args:
- batch (Any): a batch of test data.
Returns:
- outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this test step. Keys (str) are the metric names, and values (Tensor) are the metrics.