clarena.stl_algorithms

Single-Task Learning Algorithms

This submodule provides the single-task learning algorithms in CLArena.

Here are the base classes for STL algorithms, which inherit from PyTorch Lightning LightningModule:

  • STLAlgorithm: the base class for all single-task learning algorithms.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement STL algorithms:

  • Configure STL Algorithm: https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/stl-algorithm
  • Implement Custom STL Algorithm: https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/stl-algorithm

r"""

# Single-Task Learning Algorithms

This submodule provides the **single-task learning algorithms** in CLArena.

Here are the base classes for STL algorithms, which inherit from PyTorch Lightning `LightningModule`:

- `STLAlgorithm`: the base class for all single-task learning algorithms.

Please note that this is API documentation. Please refer to the main documentation pages for more information about how to configure and implement STL algorithms:

- [**Configure STL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/stl-algorithm)
- [**Implement Custom STL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/stl-algorithm)

"""

from .base import STLAlgorithm
from .single_learning import SingleLearning

__all__ = ["STLAlgorithm", "SingleLearning"]
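
Below is a minimal sketch of how a custom STL algorithm could be built on top of this submodule, in the spirit of the "Implement Custom STL Algorithm" guide linked above. The subclass name MyCrossEntropySTL, the batch structure, and the logged metric name are illustrative assumptions rather than CLArena conventions.

from typing import Any

from torch import Tensor

from clarena.stl_algorithms import STLAlgorithm


class MyCrossEntropySTL(STLAlgorithm):
    r"""A hypothetical STL algorithm trained with plain cross-entropy."""

    def training_step(self, batch: Any, batch_idx: int) -> Tensor:
        # assumes each batch is an (input, target) pair
        x, y = batch
        # `forward()` returns (logits, activations) unless
        # `if_forward_func_return_logits_only` is set to `True`
        logits, _ = self.forward(x, stage="train")
        loss = self.criterion(logits, y)
        self.log("learning_curve/train/loss_cls", loss)  # illustrative metric name
        return loss
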
class STLAlgorithm(lightning.pytorch.core.module.LightningModule):
class STLAlgorithm(LightningModule):
    r"""The base class of single-task learning algorithms."""

    def __init__(
        self,
        backbone: Backbone,
        head: HeadSTL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""
        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **head** (`HeadSTL`): output head.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters, i.e. those not related to the algorithm itself (such as optimizer and learning rate scheduler configurations), passed to this `LightningModule` object from the config. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__()
        self.save_hyperparameters(non_algorithmic_hparams)

        self.backbone: Backbone = backbone
        r"""The backbone network."""
        self.head: HeadSTL = head
        r"""The output head."""
        self.optimizer: Optimizer
        r"""Optimizer (partially initialized). Will be equipped with parameters in `configure_optimizers()`."""
        self.lr_scheduler: LRScheduler | None
        r"""The learning rate scheduler for the optimizer. If `None`, no scheduler is used."""
        self.criterion = nn.CrossEntropyLoss()
        r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.if_forward_func_return_logits_only: bool = False
        r"""Whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and the activations. Default is `False`."""

        STLAlgorithm.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check backbone and head compatibility
        if self.backbone.output_dim != self.head.input_dim:
            raise ValueError(
                "The output_dim of the backbone network must equal the input_dim of the STL head!"
            )

    def setup_task(
        self,
        num_classes: int,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ) -> None:
        r"""Set up the components for the STL algorithm. This must be done before the `forward()` method is called.

        **Args:**
        - **num_classes** (`int`): the number of classes for the single-task learning.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialized).
        - **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
        """
        self.head.setup_task(num_classes=num_classes)

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def forward(self, input: Tensor, stage: str) -> Tensor:
        r"""The forward pass for the input data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many STL algorithms, including the vanilla SingleLearning algorithm, use.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
        """
        feature, activations = self.backbone(input, stage=stage)
        logits = self.head(feature)
        return (
            logits if self.if_forward_func_return_logits_only else (logits, activations)
        )

    def configure_optimizers(self) -> Optimizer:
        r"""Configure optimizers, a hook called by Lightning. See the [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
        # finish the partially initialized optimizer by supplying the model parameters. The `parameters()` method of this `STLAlgorithm` (inherited from `LightningModule`) returns both backbone and head parameters
        fully_initialized_optimizer = self.optimizer(params=self.parameters())

        if self.lr_scheduler:
            fully_initialized_lr_scheduler = self.lr_scheduler(
                optimizer=fully_initialized_optimizer
            )

            return {
                "optimizer": fully_initialized_optimizer,
                "lr_scheduler": {
                    "scheduler": fully_initialized_lr_scheduler,
                    "monitor": "learning_curve/val/loss_cls",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }

        return {"optimizer": fully_initialized_optimizer}

The base class of single-task learning algorithms.

STLAlgorithm( backbone: clarena.backbones.Backbone, head: clarena.heads.HeadSTL, non_algorithmic_hparams: dict[str, typing.Any] = {})
    def __init__(
        self,
        backbone: Backbone,
        head: HeadSTL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""
        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **head** (`HeadSTL`): output head.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters, i.e. those not related to the algorithm itself (such as optimizer and learning rate scheduler configurations), passed to this `LightningModule` object from the config. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__()
        self.save_hyperparameters(non_algorithmic_hparams)

        self.backbone: Backbone = backbone
        r"""The backbone network."""
        self.head: HeadSTL = head
        r"""The output head."""
        self.optimizer: Optimizer
        r"""Optimizer (partially initialized). Will be equipped with parameters in `configure_optimizers()`."""
        self.lr_scheduler: LRScheduler | None
        r"""The learning rate scheduler for the optimizer. If `None`, no scheduler is used."""
        self.criterion = nn.CrossEntropyLoss()
        r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.if_forward_func_return_logits_only: bool = False
        r"""Whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and the activations. Default is `False`."""

        STLAlgorithm.sanity_check(self)

Args:

  • backbone (Backbone): backbone network.
  • head (HeadSTL): output head.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters, i.e. those not related to the algorithm itself (such as optimizer and learning rate scheduler configurations), passed to this LightningModule object from the config. They are saved for Lightning APIs through the save_hyperparameters() method, which is useful for experiment configuration and reproducibility; see the instantiation sketch below.
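
A hedged instantiation sketch, assuming backbone and head are already-constructed CLArena Backbone / HeadSTL objects (placeholders here) and using illustrative hyperparameter keys. In practice a concrete algorithm such as SingleLearning would normally be built from the config rather than by hand.

from clarena.backbones import Backbone
from clarena.heads import HeadSTL
from clarena.stl_algorithms import SingleLearning

backbone: Backbone = ...  # placeholder: an already-constructed backbone network
head: HeadSTL = ...       # placeholder: an already-constructed STL output head

# assumes `SingleLearning` keeps the base constructor signature
algorithm = SingleLearning(
    backbone=backbone,
    head=head,
    non_algorithmic_hparams={  # illustrative, non-algorithmic settings only
        "optimizer": {"name": "SGD", "lr": 0.1},
        "lr_scheduler": None,
    },
)

# the dict is stored via `save_hyperparameters()`, so it shows up in
# `algorithm.hparams` and in Lightning loggers for reproducibility
print(algorithm.hparams)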

The backbone network.

The output head.

optimizer: torch.optim.optimizer.Optimizer

Optimizer (partially initialized). Will be equipped with parameters in configure_optimizers().

lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None

The learning rate scheduler for the optimizer. If None, no scheduler is used.

criterion

The loss function between the output logits and the target labels. Default is cross-entropy loss.

if_forward_func_return_logits_only: bool

Whether the forward() method returns logits only. If False, it returns a tuple of the logits and the activations. Default is False.

def sanity_check(self) -> None:
    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check backbone and head compatibility
        if self.backbone.output_dim != self.head.input_dim:
            raise ValueError(
                "The output_dim of the backbone network must equal the input_dim of the STL head!"
            )

Sanity check.

def setup_task( self, num_classes: int, optimizer: torch.optim.optimizer.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None) -> None:
    def setup_task(
        self,
        num_classes: int,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ) -> None:
        r"""Set up the components for the STL algorithm. This must be done before the `forward()` method is called.

        **Args:**
        - **num_classes** (`int`): the number of classes for the single-task learning.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialized).
        - **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
        """
        self.head.setup_task(num_classes=num_classes)

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

Set up the components for the STL algorithm. This must be done before the forward() method is called.

Args:

  • num_classes (int): the number of classes for the single-task learning.
  • optimizer (Optimizer): the optimizer object (partially initialized); see the sketch below for how it is typically constructed.
  • lr_scheduler (LRScheduler | None): the learning rate scheduler for the optimizer. If None, no scheduler is used.
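
A minimal sketch of preparing these inputs, assuming the partially initialized optimizer and scheduler are callables that still expect the model parameters / the optimizer (which matches how configure_optimizers() completes them). The concrete classes and values are illustrative.

from functools import partial

import torch

partial_optimizer = partial(torch.optim.SGD, lr=0.1, momentum=0.9)
partial_scheduler = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau, factor=0.5, patience=3
)

# `algorithm` is an already-constructed STL algorithm (e.g. SingleLearning)
algorithm.setup_task(
    num_classes=10,                  # e.g. a 10-class dataset
    optimizer=partial_optimizer,     # completed later in configure_optimizers()
    lr_scheduler=partial_scheduler,  # or None to disable scheduling
)
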
def forward(self, input: torch.Tensor, stage: str) -> torch.Tensor:
    def forward(self, input: Tensor, stage: str) -> Tensor:
        r"""The forward pass for the input data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many STL algorithms, including the vanilla SingleLearning algorithm, use.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
        """
        feature, activations = self.backbone(input, stage=stage)
        logits = self.head(feature)
        return (
            logits if self.if_forward_func_return_logits_only else (logits, activations)
        )

The forward pass for the input data. Note that it has nothing to do with the forward() method in nn.Module. This definition provides a template that many STL algorithms, including the vanilla SingleLearning algorithm, use.

Args:

  • input (Tensor): The input tensor from data.
  • stage (str): the stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.

Returns:

  • logits (Tensor): the output logits tensor.
  • activations (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes; see the usage sketch below.
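
A short usage sketch of calling forward() directly, assuming algorithm is a concrete STL algorithm whose setup_task() has already been called; the batch shape (8 RGB images of size 32x32) is illustrative.

import torch

inputs = torch.randn(8, 3, 32, 32)  # hypothetical input batch

# default behaviour: a (logits, activations) pair is returned
logits, activations = algorithm(inputs, stage="train")
print(logits.shape)              # (8, num_classes)
print(list(activations.keys()))  # names of the weighted layers

# with the flag set, only the logits tensor comes back
algorithm.if_forward_func_return_logits_only = True
logits = algorithm(inputs, stage="test")
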
def configure_optimizers(self) -> torch.optim.optimizer.Optimizer:
    def configure_optimizers(self) -> Optimizer:
        r"""Configure optimizers, a hook called by Lightning. See the [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
        # finish the partially initialized optimizer by supplying the model parameters. The `parameters()` method of this `STLAlgorithm` (inherited from `LightningModule`) returns both backbone and head parameters
        fully_initialized_optimizer = self.optimizer(params=self.parameters())

        if self.lr_scheduler:
            fully_initialized_lr_scheduler = self.lr_scheduler(
                optimizer=fully_initialized_optimizer
            )

            return {
                "optimizer": fully_initialized_optimizer,
                "lr_scheduler": {
                    "scheduler": fully_initialized_lr_scheduler,
                    "monitor": "learning_curve/val/loss_cls",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }

        return {"optimizer": fully_initialized_optimizer}

Configure optimizers, a hook called by Lightning. See the Lightning docs for more details.
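
A small inspection sketch of what this hook returns once setup_task() has supplied a partial optimizer and scheduler (see the setup_task example above). Lightning normally calls the hook itself; invoking it manually here is only for illustration.

config = algorithm.configure_optimizers()

print(type(config["optimizer"]))          # e.g. torch.optim.SGD
print(config["lr_scheduler"]["monitor"])  # "learning_curve/val/loss_cls"

# Because the scheduler entry monitors "learning_curve/val/loss_cls", a concrete
# algorithm must log a validation metric under exactly that name, e.g.
# self.log("learning_curve/val/loss_cls", val_loss), otherwise schedulers such
# as ReduceLROnPlateau have nothing to step on.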