clarena.mtl_algorithms
Multi-Task Learning Algorithms
This submodule provides the multi-task learning algorithms in CLArena.
Here are the base classes for MTL algorithms, which inherit from PyTorch Lightning `LightningModule`:

- `MTLAlgorithm`: the base class for all multi-task learning algorithms.
Please note that this is the API documentation. Please refer to the main documentation pages for more information about how to configure and implement MTL algorithms:

- [**Configure MTL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/mtl-algorithm)
- [**Implement Custom MTL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/mtl-algorithm)
1r""" 2 3# Multi-Task Learning Algorithms 4 5This submodule provides the **multi-task learning algorithms** in CLArena. 6 7Here are the base classes for MTL algorithms, which inherit from PyTorch Lightning `LightningModule`: 8 9- `MTLAlgorithm`: the base class for all multi-task learning algorithms. 10 11 12Please note that this is an API documantation. Please refer to the main documentation pages for more information about how to configure and implement MTL algorithms: 13 14- [**Configure MTL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/mtl-algorithm) 15- [**Implement Custom MTL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/mtl-algorithm) 16 17""" 18 19from .base import MTLAlgorithm 20 21from .joint_learning import JointLearning 22 23__all__ = ["MTLAlgorithm", "joint_learning"]
```python
class MTLAlgorithm(LightningModule):
    r"""The base class of multi-task learning algorithms."""

    def __init__(
        self,
        backbone: Backbone,
        heads: HeadsMTL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""
        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **heads** (`HeadsMTL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters, i.e. those not related to the algorithm itself, that are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__()
        self.save_hyperparameters(non_algorithmic_hparams)

        # components
        self.backbone: Backbone = backbone
        r"""The backbone network."""
        self.heads: HeadsMTL = heads
        r"""The output heads."""
        self.optimizer: Optimizer
        r"""Optimizer (partially initialized) for the backpropagation. Will be equipped with parameters in `configure_optimizers()`."""
        self.lr_scheduler: LRScheduler | None
        r"""The learning rate scheduler for the optimizer. If `None`, no scheduler is used."""
        self.criterion = nn.CrossEntropyLoss()
        r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.if_forward_func_return_logits_only: bool = False
        r"""Whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and other information (such as activations). Default is `False`."""

        MTLAlgorithm.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check: verify that the backbone and heads are compatible."""

        # check backbone and heads compatibility
        if self.backbone.output_dim != self.heads.input_dim:
            raise ValueError(
                "The output_dim of backbone network should be equal to the input_dim of MTL heads!"
            )

    def setup_tasks(
        self,
        task_ids: list[int],
        num_classes: dict[int, int],
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ) -> None:
        r"""Set up tasks for the MTL algorithm. This must be done before the `forward()` method is called.

        **Args:**
        - **task_ids** (`list[int]`): the list of task IDs.
        - **num_classes** (`dict[int, int]`): a dictionary mapping each task ID to its number of classes.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialized).
        - **lr_scheduler** (`LRScheduler | None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
        """
        self.heads.setup_tasks(task_ids=task_ids, num_classes=num_classes)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def get_val_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        r"""Get the validation task ID from the dataloader index.

        **Args:**
        - **dataloader_idx** (`int`): the dataloader index.

        **Returns:**
        - **val_task_id** (`int`): the validation task ID.
        """
        dataset_val = self.trainer.datamodule.dataset_val
        val_task_id = list(dataset_val.keys())[dataloader_idx]
        return val_task_id

    def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        r"""Get the test task ID from the dataloader index.

        **Args:**
        - **dataloader_idx** (`int`): the dataloader index.

        **Returns:**
        - **test_task_id** (`int`): the test task ID.
        """
        dataset_test = self.trainer.datamodule.dataset_test
        test_task_id = list(dataset_test.keys())[dataloader_idx]
        return test_task_id

    def set_forward_func_return_logits_only(
        self, forward_func_return_logits_only: bool
    ) -> None:
        r"""Set whether the `forward()` method returns logits only. This is useful for some CL algorithms that require the forward function to return logits only, such as FG-AdaHAT.

        **Args:**
        - **forward_func_return_logits_only** (`bool`): whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and other information (such as activations).
        """
        self.if_forward_func_return_logits_only = forward_func_return_logits_only

    def forward(
        self, input: Tensor, task_ids: int | Tensor, stage: str
    ) -> Tensor | tuple[Tensor, dict[str, Tensor]]:
        r"""The forward pass for data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many MTL algorithms, including the vanilla JointLearning algorithm, use.

        This forward pass does not accept an input batch that mixes different tasks. Please make sure the input batch is from a single task. If you want to use this forward pass for different tasks, please divide the input batch by task and call this forward pass for each task separately.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **task_ids** (`int` | `Tensor`): the task ID(s) for the input data. If the input batch is from the same task, this can be a single integer.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name; value (`Tensor`) is the hidden feature tensor. This is used by the continual learning algorithms that need the hidden features for various purposes.
        """
        feature, activations = self.backbone(input, stage=stage)
        logits = self.heads(feature, task_ids)
        return (
            logits if self.if_forward_func_return_logits_only else (logits, activations)
        )

    def configure_optimizers(self) -> dict[str, Any]:
        r"""Configure optimizer hooks by Lightning. See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
        # finish the partially initialized optimizer by specifying model parameters.
        # The `parameters()` method of this `MTLAlgorithm` (inherited from
        # `LightningModule`) returns both backbone and heads parameters.
        fully_initialized_optimizer = self.optimizer(params=self.parameters())

        if self.lr_scheduler:
            fully_initialized_lr_scheduler = self.lr_scheduler(
                optimizer=fully_initialized_optimizer
            )

            return {
                "optimizer": fully_initialized_optimizer,
                "lr_scheduler": {
                    "scheduler": fully_initialized_lr_scheduler,
                    "monitor": "learning_curve/val/loss_cls",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }

        return {"optimizer": fully_initialized_optimizer}
```
The base class of multi-task learning algorithms.
`MTLAlgorithm(backbone: Backbone, heads: HeadsMTL, non_algorithmic_hparams: dict[str, Any] = {})`
**Args:**

- **backbone** (`Backbone`): backbone network.
- **heads** (`HeadsMTL`): output heads.
- **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters, i.e. those not related to the algorithm itself, that are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.

**Attributes:**

- `backbone` (`Backbone`): the backbone network.
- `heads` (`HeadsMTL`): the output heads.
- `optimizer` (`Optimizer`): optimizer (partially initialized) for the backpropagation. Will be equipped with parameters in `configure_optimizers()`.
- `lr_scheduler` (`LRScheduler | None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
- `criterion`: the loss function between the output logits and the target labels. Default is cross-entropy loss.
- `if_forward_func_return_logits_only` (`bool`): whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and other information (such as activations). Default is `False`.
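A hedged construction sketch. The `backbone` and `heads` objects are assumed to be concrete `Backbone` and `HeadsMTL` instances built elsewhere from CLArena components, and the hyperparameter dictionary is illustrative only (in practice it is filled from the experiment config):

```python
# A minimal sketch, assuming `backbone` and `heads` are compatible CLArena
# components (backbone.output_dim == heads.input_dim, or sanity_check()
# raises a ValueError). The hparams dict below is a hypothetical example.
algorithm = MTLAlgorithm(
    backbone=backbone,
    heads=heads,
    non_algorithmic_hparams={"optimizer": "Adam", "lr": 1e-3},
)
```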
`def sanity_check(self) -> None`
Sanity check: verifies that the `output_dim` of the backbone network equals the `input_dim` of the MTL heads, and raises a `ValueError` otherwise.
`def setup_tasks(self, task_ids: list[int], num_classes: dict[int, int], optimizer: Optimizer, lr_scheduler: LRScheduler | None) -> None`
Set up tasks for the MTL algorithm. This must be done before the `forward()` method is called.

**Args:**

- **task_ids** (`list[int]`): the list of task IDs.
- **num_classes** (`dict[int, int]`): a dictionary mapping each task ID to its number of classes.
- **optimizer** (`Optimizer`): the optimizer object (partially initialized).
- **lr_scheduler** (`LRScheduler | None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
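Because the optimizer and scheduler are passed in partially initialized (their model parameters are bound later in `configure_optimizers()`), one natural way to build them is with `functools.partial`. A minimal sketch, assuming an existing `algorithm` instance and three hypothetical 10-class tasks:

```python
from functools import partial

import torch

# Partially initialized optimizer and scheduler: everything is fixed except
# the `params` / `optimizer` arguments, which configure_optimizers() supplies.
optimizer = partial(torch.optim.Adam, lr=1e-3)
lr_scheduler = partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.1)

algorithm.setup_tasks(
    task_ids=[1, 2, 3],
    num_classes={1: 10, 2: 10, 3: 10},  # assumed: each task has 10 classes
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)
```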
`def get_val_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int`
Get the validation task ID from the dataloader index.

**Args:**

- **dataloader_idx** (`int`): the dataloader index.

**Returns:**

- **val_task_id** (`int`): the validation task ID.
`def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int`
Get the test task ID from the dataloader index.

**Args:**

- **dataloader_idx** (`int`): the dataloader index.

**Returns:**

- **test_task_id** (`int`): the test task ID.
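Both this method and its validation counterpart rely on the datamodule storing its datasets as dictionaries keyed by task ID, so the i-th dataloader corresponds to the i-th key. A small illustration with assumed dictionary contents:

```python
# Illustration only: assumed shape of the datamodule's dataset dictionary.
# Task IDs need not be contiguous; the dataloader index follows key order.
dataset_val = {1: "val_dataset_task_1", 3: "val_dataset_task_3"}

dataloader_idx = 1
val_task_id = list(dataset_val.keys())[dataloader_idx]
print(val_task_id)  # -> 3: the second validation dataloader serves task 3
```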
`def set_forward_func_return_logits_only(self, forward_func_return_logits_only: bool) -> None`
Set whether the `forward()` method returns logits only. This is useful for some CL algorithms that require the forward function to return logits only, such as FG-AdaHAT.

**Args:**

- **forward_func_return_logits_only** (`bool`): whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and other information (such as activations).
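A short sketch of the toggle's effect on the return value of `forward()`, assuming an `algorithm` instance and an input batch `x` from task 1:

```python
# Default behaviour: forward() returns a (logits, activations) tuple.
logits, activations = algorithm(x, task_ids=1, stage="train")

# After the toggle, forward() returns the logits tensor only.
algorithm.set_forward_func_return_logits_only(True)
logits = algorithm(x, task_ids=1, stage="train")
```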
`def forward(self, input: Tensor, task_ids: int | Tensor, stage: str) -> Tensor | tuple[Tensor, dict[str, Tensor]]`
The forward pass for data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many MTL algorithms, including the vanilla JointLearning algorithm, use.

This forward pass does not accept an input batch that mixes different tasks. Please make sure the input batch is from a single task. If you want to use this forward pass for different tasks, please divide the input batch by task and call this forward pass for each task separately (see the sketch after the argument list below).

**Args:**

- **input** (`Tensor`): the input tensor from data.
- **task_ids** (`int` | `Tensor`): the task ID(s) for the input data. If the input batch is from the same task, this can be a single integer.
- **stage** (`str`): the stage of the forward pass; one of:
  1. 'train': training stage.
  2. 'validation': validation stage.
  3. 'test': testing stage.

**Returns:**

- **logits** (`Tensor`): the output logits tensor.
- **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name; value (`Tensor`) is the hidden feature tensor. This is used by the continual learning algorithms that need the hidden features for various purposes.
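A hedged sketch of the splitting pattern for a mixed-task batch. The helper name and the batch layout are assumptions for illustration, not part of the documented API:

```python
from torch import Tensor

from clarena.mtl_algorithms import MTLAlgorithm


def forward_mixed_batch(
    algorithm: MTLAlgorithm, input: Tensor, task_ids: Tensor, stage: str = "train"
) -> dict[int, Tensor]:
    """Hypothetical helper: split a mixed-task batch by task ID and run the
    single-task forward pass once per task."""
    logits_per_task: dict[int, Tensor] = {}
    for task_id in task_ids.unique().tolist():
        mask = task_ids == task_id  # select the samples belonging to this task
        # Assumes the default tuple return (logits, activations).
        logits, _ = algorithm(input[mask], task_ids=task_id, stage=stage)
        logits_per_task[task_id] = logits
    return logits_per_task
```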
`def configure_optimizers(self) -> dict[str, Any]`
Configure optimizer hooks by Lightning. See the [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details. The partially initialized optimizer from `setup_tasks()` is completed here with the model parameters (both backbone and heads); if a scheduler was given, it is returned alongside the optimizer, monitoring `learning_curve/val/loss_cls` once per epoch.
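Finally, a minimal sketch of how a custom MTL algorithm might build on this base class. The `training_step` body, the batch layout, and the logging key are assumptions for illustration; see the "Implement Custom MTL Algorithm" page for the real workflow:

```python
from torch import Tensor

from clarena.mtl_algorithms import MTLAlgorithm


class MyMTLAlgorithm(MTLAlgorithm):
    r"""A hypothetical custom multi-task learning algorithm."""

    def training_step(self, batch: tuple, batch_idx: int) -> Tensor:
        # Assumed batch layout: a single-task batch of (inputs, labels, task ID).
        x, y, task_id = batch
        logits, _activations = self.forward(x, task_ids=task_id, stage="train")
        loss = self.criterion(logits, y)  # cross-entropy by default
        # The logging key mirrors the "learning_curve/val/loss_cls" monitor
        # used in configure_optimizers(); it is an assumption here.
        self.log("learning_curve/train/loss_cls", loss)
        return loss
```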