clarena.mtl_algorithms.joint_learning
The submodule in `mtl_algorithms` for the joint learning algorithm.
1r""" 2The submodule in `mtl_algorithms` for joint learning algorithm. 3""" 4 5__all__ = ["JointLearning"] 6 7import logging 8from typing import Any 9 10import torch 11from torch import Tensor 12from torch.utils.data import DataLoader 13 14from clarena.backbones import Backbone 15from clarena.heads import HeadsMTL 16from clarena.mtl_algorithms import MTLAlgorithm 17 18# always get logger for built-in logging in each module 19pylogger = logging.getLogger(__name__) 20 21 22class JointLearning(MTLAlgorithm): 23 r"""Joint learning algorithm. 24 25 The most naive way for multi-task learning. It directly trains all tasks. 26 """ 27 28 def __init__( 29 self, 30 backbone: Backbone, 31 heads: HeadsMTL, 32 non_algorithmic_hparams: dict[str, Any] = {}, 33 ) -> None: 34 r"""Initialize the JointLearning algorithm with the network. It has no additional hyperparameters. 35 36 **Args:** 37 - **backbone** (`Backbone`): backbone network. 38 - **heads** (`HeadsMTL`): output heads. 39 - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs from `save_hyperparameters()` method. This is useful for the experiment configuration and reproducibility. 40 """ 41 super().__init__( 42 backbone=backbone, 43 heads=heads, 44 non_algorithmic_hparams=non_algorithmic_hparams, 45 ) 46 47 def training_step(self, batch: Any) -> dict[str, Tensor]: 48 r"""Training step. 49 50 **Args:** 51 - **batch** (`Any`): a batch of training data, which can be from any mixed tasks. Must include task IDs in the batch. 52 53 **Returns:** 54 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and accuracy from this training step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. Must include the key 'loss' which is total loss in the case of automatic optimization, according to PyTorch Lightning docs. 
55 """ 56 x, y, task_ids = batch # train data are provided task ID in case of MTL 57 logits, activations = self.forward(x, stage="train", task_ids=task_ids) 58 59 # the data are from different tasks, so we need to calculate the loss and accuracy for each task separately 60 preds = torch.zeros_like(y) 61 loss_cls = 0.0 62 acc = 0.0 63 64 for task_id in torch.unique(task_ids): # for each unique task in the batch 65 idx = (task_ids == task_id).nonzero(as_tuple=True)[ 66 0 67 ] # indices of the current task in the batch 68 logits_t = logits[idx] # get the logits for the current task 69 y_t = y[idx] # class labels for the current task 70 71 # classification loss 72 loss_cls_t = self.criterion(logits_t, y_t) 73 loss_cls = loss_cls + loss_cls_t 74 75 # predicted labels of this task 76 preds_t = logits_t.argmax(dim=1) 77 preds[idx] = preds_t 78 79 # accuracy of this task 80 acc_task = (preds_t == y_t).float().mean() 81 acc = acc + acc_task 82 83 loss_cls = loss_cls / len(torch.unique(task_ids)) # average loss over tasks 84 acc = acc / len(torch.unique(task_ids)) # average accuracy over tasks 85 86 # total loss 87 loss = loss_cls 88 89 return { 90 "preds": preds, 91 "loss": loss, # return loss is essential for training step, or backpropagation will fail 92 "loss_cls": loss_cls, 93 "acc": acc, # return other metrics for lightning loggers callback to handle at `on_train_batch_end()` 94 "activations": activations, 95 } 96 97 def validation_step( 98 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 99 ) -> dict[str, Tensor]: 100 r"""Validation step. This is done task by task rather than mixing the tasks in batches. 101 102 **Args:** 103 - **batch** (`Any`): a batch of validation data. 104 - **dataloader_idx** (`int`): the task ID of seen tasks to be validated. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 105 106 **Returns:** 107 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and accuracy from this validation step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 108 """ 109 val_task_id = self.get_val_task_id_from_dataloader_idx(dataloader_idx) 110 111 x, y, _ = batch # validation data are not provided task ID 112 113 # the batch is from the same task, so no need to divide the input batch by tasks 114 logits, _ = self.forward( 115 x, stage="validation", task_ids=val_task_id 116 ) # use the corresponding head to get the logits 117 loss_cls = self.criterion(logits, y) 118 preds = logits.argmax(dim=1) 119 acc = (preds == y).float().mean() 120 121 # return metrics for lightning loggers callback to handle at `on_validation_batch_end()` 122 return { 123 "preds": preds, 124 "loss_cls": loss_cls, 125 "acc": acc, 126 } 127 128 def test_step( 129 self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0 130 ) -> dict[str, Tensor]: 131 r"""Test step. This is done task by task rather than mixing the tasks in batches. 132 133 **Args:** 134 - **batch** (`Any`): a batch of test data. 135 - **dataloader_idx** (`int`): the task ID of seen tasks to be tested. A default value of 0 is given otherwise the LightningModule will raise a `RuntimeError`. 136 137 **Returns:** 138 - **outputs** (`dict[str, Tensor]`): a dictionary contains loss and accuracy from this test step. Keys (`str`) are the metrics names, and values (`Tensor`) are the metrics. 
139 """ 140 test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx) 141 142 x, y, _ = batch 143 144 # the batch is from the same task, so no need to divide the input batch by tasks 145 logits, _ = self.forward( 146 x, stage="test", task_ids=test_task_id 147 ) # use the corresponding head to get the logits 148 loss_cls = self.criterion(logits, y) 149 preds = logits.argmax(dim=1) 150 acc = (preds == y).float().mean() 151 152 # return metrics for lightning loggers callback to handle at `on_test_batch_end()` 153 return { 154 "preds": preds, 155 "loss_cls": loss_cls, 156 "acc": acc, 157 }
`class JointLearning(clarena.mtl_algorithms.base.MTLAlgorithm):`
Joint learning algorithm.
The most naive approach to multi-task learning. It directly trains on all tasks.
`JointLearning(backbone: clarena.backbones.Backbone, heads: clarena.heads.HeadsMTL, non_algorithmic_hparams: dict[str, typing.Any] = {})`
Initialize the JointLearning algorithm with the network. It has no additional hyperparameters.
**Args:**
- **backbone** (`Backbone`): backbone network.
- **heads** (`HeadsMTL`): output heads.
- **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters, i.e. those not related to the algorithm itself, passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
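A minimal instantiation sketch follows. Only the `JointLearning(...)` call mirrors the signature documented above; how the backbone, heads, and `non_algorithmic_hparams` are actually constructed depends on the experiment config, so the placeholders and config keys below are hypothetical.

```python
# Hedged sketch: the `...` placeholders and the hparam keys are illustrative,
# not the library's actual constructor arguments or config schema.
from clarena.backbones import Backbone
from clarena.heads import HeadsMTL
from clarena.mtl_algorithms import JointLearning

backbone: Backbone = ...   # a backbone network from clarena.backbones
heads: HeadsMTL = ...      # multi-task output heads, one per task

model = JointLearning(
    backbone=backbone,
    heads=heads,
    non_algorithmic_hparams={"optimizer": {...}, "lr_scheduler": {...}},  # hypothetical keys
)
```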
`def training_step(self, batch: Any) -> dict[str, torch.Tensor]:`
Training step.
**Args:**
- **batch** (`Any`): a batch of training data, which can be from any mix of tasks. Must include task IDs.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this training step. Keys (`str`) are the metric names and values (`Tensor`) are the metrics. Must include the key `'loss'`, which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
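To make the per-task averaging in `training_step` concrete, here is a self-contained sketch of the same computation on a toy mixed batch, using plain PyTorch. A cross-entropy criterion stands in for `self.criterion` and random logits stand in for `self.forward(...)`; both are assumptions for illustration only.

```python
# Standalone illustration of the per-task loss/accuracy averaging over a mixed batch.
import torch
from torch import nn

criterion = nn.CrossEntropyLoss()  # stand-in for self.criterion

num_classes = 5
y = torch.tensor([0, 3, 1, 4, 2, 2])         # class labels of a mixed batch
task_ids = torch.tensor([0, 0, 1, 1, 1, 2])  # task ID of each sample
logits = torch.randn(len(y), num_classes)    # stand-in for self.forward(...)

preds = torch.zeros_like(y)
loss_cls, acc = 0.0, 0.0
unique_tasks = torch.unique(task_ids)

for task_id in unique_tasks:
    idx = (task_ids == task_id).nonzero(as_tuple=True)[0]  # samples of this task
    loss_cls = loss_cls + criterion(logits[idx], y[idx])   # per-task classification loss
    preds[idx] = logits[idx].argmax(dim=1)                  # per-task predictions
    acc = acc + (preds[idx] == y[idx]).float().mean()       # per-task accuracy

loss_cls = loss_cls / len(unique_tasks)  # average over tasks, not over samples
acc = acc / len(unique_tasks)
print(loss_cls.item(), acc.item())
```

Note that averaging per task rather than per sample means every task present in the batch contributes equally to the loss, regardless of how many samples it happens to have in that batch.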
`def validation_step(self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:`
Validation step. This is done task by task rather than mixing the tasks in batches.
**Args:**
- **batch** (`Any`): a batch of validation data.
- **dataloader_idx** (`int`): the task ID of the seen task to be validated. A default value of 0 is given; otherwise, the `LightningModule` will raise a `RuntimeError`.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this validation step. Keys (`str`) are the metric names and values (`Tensor`) are the metrics.
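The mapping from `dataloader_idx` to a task ID is handled by `get_val_task_id_from_dataloader_idx`, inherited from `MTLAlgorithm`, whose implementation is not shown on this page. A plausible sketch, assuming the seen task IDs are stored in the same order as the validation dataloaders, might look like this (hypothetical, not the library's actual code):

```python
# Hypothetical sketch: if seen task IDs are kept in dataloader order,
# the mapping from dataloader index to task ID is a simple lookup.
def val_task_id_from_dataloader_idx(val_task_ids: list[int], dataloader_idx: int) -> int:
    return val_task_ids[dataloader_idx]

# e.g. val_task_ids = [1, 2, 3]; dataloader_idx 0 -> task 1
```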
`def test_step(self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:`
Test step. This is done task by task rather than mixing the tasks in batches.
**Args:**
- **batch** (`Any`): a batch of test data.
- **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given; otherwise, the `LightningModule` will raise a `RuntimeError`.

**Returns:**
- **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this test step. Keys (`str`) are the metric names and values (`Tensor`) are the metrics.