clarena.stl_algorithms
Single-Task Learning Algorithms
This submodule provides the **single-task learning algorithms** in CLArena.

Here are the base classes for STL algorithms, which inherit from PyTorch Lightning `LightningModule`:

- `STLAlgorithm`: the base class for all single-task learning algorithms.
Please note that this is the API documentation. Please refer to the main documentation pages for more information about how to configure and implement STL algorithms:

- [**Configure STL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/stl-algorithm)
- [**Implement Custom STL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/stl-algorithm)
```python
r"""

# Single-Task Learning Algorithms

This submodule provides the **single-task learning algorithms** in CLArena.

Here are the base classes for STL algorithms, which inherit from PyTorch Lightning `LightningModule`:

- `STLAlgorithm`: the base class for all single-task learning algorithms.

Please note that this is the API documentation. Please refer to the main documentation pages for more information about how to configure and implement STL algorithms:

- [**Configure STL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/stl-algorithm)
- [**Implement Custom STL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/stl-algorithm)

"""

from .base import STLAlgorithm

from .single_learning import SingleLearning

__all__ = ["STLAlgorithm", "single_learning"]
```
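For orientation, here is a minimal usage sketch of the names this submodule exposes. It assumes the re-exports shown in the source above; the subclass body is a placeholder, not part of the API.

```python
# A minimal sketch of importing from this submodule, assuming the re-exports
# shown in the source above. The subclass body is a placeholder.
from clarena.stl_algorithms import STLAlgorithm, SingleLearning


# SingleLearning is the vanilla STL algorithm; custom algorithms are expected
# to subclass STLAlgorithm (see the class documentation below).
class MySTLAlgorithm(STLAlgorithm):
    ...
```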
```python
class STLAlgorithm(LightningModule):
    r"""The base class of single-task learning algorithms."""

    def __init__(
        self,
        backbone: Backbone,
        head: HeadSTL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""
        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **head** (`HeadSTL`): output head.
        - **non_algorithmic_hparams** (`dict[str, Any]`): hyperparameters that are not related to the algorithm itself but are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__()
        self.save_hyperparameters(non_algorithmic_hparams)

        self.backbone: Backbone = backbone
        r"""The backbone network."""
        self.head: HeadSTL = head
        r"""The output head."""
        self.optimizer: Optimizer
        r"""The optimizer (partially initialized). Will be equipped with parameters in `configure_optimizers()`."""
        self.lr_scheduler: LRScheduler | None
        r"""The learning rate scheduler for the optimizer. If `None`, no scheduler is used."""
        self.criterion = nn.CrossEntropyLoss()
        r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.if_forward_func_return_logits_only: bool = False
        r"""Whether the `forward()` method returns the logits only. If `False`, it returns a tuple of the logits and the hidden activations. Default is `False`."""

        STLAlgorithm.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check backbone and heads compatibility
        if self.backbone.output_dim != self.head.input_dim:
            raise ValueError(
                "The output_dim of backbone network should be equal to the input_dim of STL head!"
            )

    def setup_task(
        self,
        num_classes: int,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ) -> None:
        r"""Set up the components for the STL algorithm. This must be done before the `forward()` method is called.

        **Args:**
        - **num_classes** (`int`): the number of classes for single-task learning.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialized).
        - **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
        """
        self.head.setup_task(num_classes=num_classes)

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def forward(
        self, input: Tensor, stage: str
    ) -> Tensor | tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for the input data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many STL algorithms, including the vanilla SingleLearning algorithm, use.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by the continual learning algorithms that need the hidden features for various purposes.
        """
        feature, activations = self.backbone(input, stage=stage)
        logits = self.head(feature)
        return (
            logits if self.if_forward_func_return_logits_only else (logits, activations)
        )

    def configure_optimizers(self) -> dict:
        r"""Configure the optimizers hook for Lightning. See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
        # finish the partially initialized optimizer by specifying the model parameters. The `parameters()` method
        # of this `STLAlgorithm` (inherited from `LightningModule`) returns both backbone and head parameters.
        fully_initialized_optimizer = self.optimizer(params=self.parameters())

        if self.lr_scheduler:
            fully_initialized_lr_scheduler = self.lr_scheduler(
                optimizer=fully_initialized_optimizer
            )

            return {
                "optimizer": fully_initialized_optimizer,
                "lr_scheduler": {
                    "scheduler": fully_initialized_lr_scheduler,
                    "monitor": "learning_curve/val/loss_cls",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }

        return {"optimizer": fully_initialized_optimizer}
```
The base class of single-task learning algorithms.
```python
def __init__(
    self,
    backbone: Backbone,
    head: HeadSTL,
    non_algorithmic_hparams: dict[str, Any] = {},
) -> None:
    r"""
    **Args:**
    - **backbone** (`Backbone`): backbone network.
    - **head** (`HeadSTL`): output head.
    - **non_algorithmic_hparams** (`dict[str, Any]`): hyperparameters that are not related to the algorithm itself but are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
    """
    super().__init__()
    self.save_hyperparameters(non_algorithmic_hparams)

    self.backbone: Backbone = backbone
    r"""The backbone network."""
    self.head: HeadSTL = head
    r"""The output head."""
    self.optimizer: Optimizer
    r"""The optimizer (partially initialized). Will be equipped with parameters in `configure_optimizers()`."""
    self.lr_scheduler: LRScheduler | None
    r"""The learning rate scheduler for the optimizer. If `None`, no scheduler is used."""
    self.criterion = nn.CrossEntropyLoss()
    r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

    self.if_forward_func_return_logits_only: bool = False
    r"""Whether the `forward()` method returns the logits only. If `False`, it returns a tuple of the logits and the hidden activations. Default is `False`."""

    STLAlgorithm.sanity_check(self)
```
**Args:**

- **backbone** (`Backbone`): backbone network.
- **head** (`HeadSTL`): output head.
- **non_algorithmic_hparams** (`dict[str, Any]`): hyperparameters that are not related to the algorithm itself but are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
Instance attributes (documented in the source above):

- `backbone`: the backbone network.
- `head`: the output head.
- `optimizer`: the optimizer (partially initialized). Will be equipped with parameters in `configure_optimizers()`.
- `lr_scheduler`: the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
- `criterion`: the loss function between the output logits and the target labels. Default is cross-entropy loss.
- `if_forward_func_return_logits_only`: whether the `forward()` method returns the logits only. If `False`, it returns a tuple of the logits and the hidden activations. Default is `False`.
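To illustrate how the constructor arguments fit together, here is a rough construction sketch. It assumes `backbone` and `head` are already-constructed CLArena `Backbone` and `HeadSTL` instances, and that `SingleLearning` keeps the base-class constructor signature; both are assumptions, not guarantees of this API.

```python
# Sketch of constructing an STL algorithm. `backbone` and `head` are assumed
# to be existing CLArena Backbone / HeadSTL instances, and SingleLearning is
# assumed to keep the base-class constructor signature shown above.
algorithm = SingleLearning(
    backbone=backbone,
    head=head,
    non_algorithmic_hparams={"optimizer": {"lr": 0.1}, "batch_size": 128},
)

# save_hyperparameters() makes the non-algorithmic hyperparameters available
# through Lightning's `hparams` for logging and reproducibility:
print(algorithm.hparams["batch_size"])  # 128
```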
```python
def sanity_check(self) -> None:
    r"""Sanity check."""

    # check backbone and heads compatibility
    if self.backbone.output_dim != self.head.input_dim:
        raise ValueError(
            "The output_dim of backbone network should be equal to the input_dim of STL head!"
        )
```
Sanity check.
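The only contract checked here is the dimension match between the backbone output and the head input. A toy illustration of that contract, using `SimpleNamespace` stand-ins rather than real CLArena objects:

```python
from types import SimpleNamespace

# Stand-ins for the real Backbone / HeadSTL objects; only the output_dim /
# input_dim attributes matter for this check.
backbone = SimpleNamespace(output_dim=64)
head = SimpleNamespace(input_dim=64)

# This is the condition sanity_check() enforces at construction time;
# a mismatch raises ValueError in STLAlgorithm.__init__.
assert backbone.output_dim == head.input_dim
```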
```python
def setup_task(
    self,
    num_classes: int,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler | None,
) -> None:
    r"""Set up the components for the STL algorithm. This must be done before the `forward()` method is called.

    **Args:**
    - **num_classes** (`int`): the number of classes for single-task learning.
    - **optimizer** (`Optimizer`): the optimizer object (partially initialized).
    - **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
    """
    self.head.setup_task(num_classes=num_classes)

    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
```
Set up the components for the STL algorithm. This must be done before the `forward()` method is called.

**Args:**

- **num_classes** (`int`): the number of classes for single-task learning.
- **optimizer** (`Optimizer`): the optimizer object (partially initialized).
- **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
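As a sketch of the "partially initialized" optimizer pattern that `setup_task()` expects, one way to build such objects is with `functools.partial`; using `partial` here is an assumption about how the config system instantiates them, and `num_classes=10` is arbitrary. `algorithm` refers to the construction sketch above.

```python
from functools import partial

from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

# `algorithm` is an STLAlgorithm instance (see the construction sketch above).
# The optimizer and scheduler are passed as callables that are still missing
# their `params` / `optimizer` arguments; configure_optimizers() fills them in.
algorithm.setup_task(
    num_classes=10,
    optimizer=partial(SGD, lr=0.1, momentum=0.9),
    lr_scheduler=partial(ReduceLROnPlateau, factor=0.1, patience=5),
)
```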
```python
def forward(
    self, input: Tensor, stage: str
) -> Tensor | tuple[Tensor, dict[str, Tensor]]:
    r"""The forward pass for the input data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many STL algorithms, including the vanilla SingleLearning algorithm, use.

    **Args:**
    - **input** (`Tensor`): the input tensor from data.
    - **stage** (`str`): the stage of the forward pass; one of:
        1. 'train': training stage.
        2. 'validation': validation stage.
        3. 'test': testing stage.

    **Returns:**
    - **logits** (`Tensor`): the output logits tensor.
    - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by the continual learning algorithms that need the hidden features for various purposes.
    """
    feature, activations = self.backbone(input, stage=stage)
    logits = self.head(feature)
    return (
        logits if self.if_forward_func_return_logits_only else (logits, activations)
    )
```
The forward pass for the input data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many STL algorithms, including the vanilla SingleLearning algorithm, use.

**Args:**

- **input** (`Tensor`): the input tensor from data.
- **stage** (`str`): the stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.

**Returns:**

- **logits** (`Tensor`): the output logits tensor.
- **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by the continual learning algorithms that need the hidden features for various purposes.
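A quick sketch of calling the module once it is set up. The input shape is arbitrary and depends on the configured backbone, and `algorithm` is the instance from the sketches above; both are assumptions for illustration.

```python
import torch

# Hypothetical batch; the actual input shape depends on the configured backbone.
x = torch.randn(32, 1, 28, 28)

# By default forward() returns both the logits and the per-layer activations.
logits, activations = algorithm(x, stage="train")

# With the logits-only switch enabled, the same call returns just the logits.
algorithm.if_forward_func_return_logits_only = True
logits = algorithm(x, stage="train")
```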
```python
def configure_optimizers(self) -> dict:
    r"""Configure the optimizers hook for Lightning. See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
    # finish the partially initialized optimizer by specifying the model parameters. The `parameters()` method
    # of this `STLAlgorithm` (inherited from `LightningModule`) returns both backbone and head parameters.
    fully_initialized_optimizer = self.optimizer(params=self.parameters())

    if self.lr_scheduler:
        fully_initialized_lr_scheduler = self.lr_scheduler(
            optimizer=fully_initialized_optimizer
        )

        return {
            "optimizer": fully_initialized_optimizer,
            "lr_scheduler": {
                "scheduler": fully_initialized_lr_scheduler,
                "monitor": "learning_curve/val/loss_cls",
                "interval": "epoch",
                "frequency": 1,
            },
        }

    return {"optimizer": fully_initialized_optimizer}
```
Configure the optimizers hook for Lightning. See the [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details.
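A self-contained illustration of the step this hook performs: calling the partially initialized optimizer with the module's parameters yields a fully initialized one. The `nn.Linear` module and `functools.partial` here are stand-ins for the real backbone/head parameters and the config-instantiated optimizer, not part of this API.

```python
from functools import partial

import torch.nn as nn
from torch.optim import SGD

# Stand-in for the combined backbone + head parameters of an STLAlgorithm.
model = nn.Linear(4, 2)

# A "partially initialized" optimizer, as passed to setup_task() in the sketch above.
partial_optimizer = partial(SGD, lr=0.1)

# What configure_optimizers() does: supply the parameters to finish it.
optimizer = partial_optimizer(params=model.parameters())
print(optimizer)
```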