clarena.mtl_algorithms.joint_learning

The submodule in `mtl_algorithms` for the joint learning algorithm.

r"""
The submodule in `mtl_algorithms` for the joint learning algorithm.
"""

__all__ = ["JointLearning"]

import logging
from typing import Any

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from clarena.backbones import Backbone
from clarena.heads import HeadsMTL
from clarena.mtl_algorithms import MTLAlgorithm

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class JointLearning(MTLAlgorithm):
    r"""Joint learning algorithm.

    The most naive approach to multi-task learning: it directly trains on all tasks jointly.
    """

    def __init__(
        self,
        backbone: Backbone,
        heads: HeadsMTL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""Initialize the JointLearning algorithm with the network. It has no additional hyperparameters.

        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **heads** (`HeadsMTL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters, i.e. hyperparameters not related to the algorithm itself, passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method. This is useful for experiment configuration and reproducibility.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )

    def training_step(self, batch: Any) -> dict[str, Tensor]:
        r"""Training step.

        **Args:**
        - **batch** (`Any`): a batch of training data, which can come from any mix of tasks. Must include task IDs in the batch.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this training step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
        """
        x, y, task_ids = batch  # training data come with task IDs in the case of MTL
        logits, activations = self.forward(x, stage="train", task_ids=task_ids)

        # the data come from different tasks, so the loss and accuracy are calculated for each task separately
        preds = torch.zeros_like(y)
        loss_cls = 0.0
        acc = 0.0

        for task_id in torch.unique(task_ids):  # for each unique task in the batch
            idx = (task_ids == task_id).nonzero(as_tuple=True)[
                0
            ]  # indices of the current task in the batch
            logits_t = logits[idx]  # logits for the current task
            y_t = y[idx]  # class labels for the current task

            # classification loss
            loss_cls_t = self.criterion(logits_t, y_t)
            loss_cls = loss_cls + loss_cls_t

            # predicted labels of this task
            preds_t = logits_t.argmax(dim=1)
            preds[idx] = preds_t

            # accuracy of this task
            acc_task = (preds_t == y_t).float().mean()
            acc = acc + acc_task

        loss_cls = loss_cls / len(torch.unique(task_ids))  # average loss over tasks
        acc = acc / len(torch.unique(task_ids))  # average accuracy over tasks

        # total loss
        loss = loss_cls

        return {
            "preds": preds,
            "loss": loss,  # returning the loss is essential for the training step, or backpropagation will fail
            "loss_cls": loss_cls,
            "acc": acc,  # return other metrics for the lightning loggers callback to handle at `on_train_batch_end()`
            "activations": activations,
        }

    def validation_step(
        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Validation step. This is done task by task rather than mixing the tasks in batches.

        **Args:**
        - **batch** (`Any`): a batch of validation data.
        - **dataloader_idx** (`int`): the task ID of the seen task to be validated. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this validation step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
        """
        val_task_id = self.get_val_task_id_from_dataloader_idx(dataloader_idx)

        x, y, _ = batch  # validation data are not provided with task IDs

        # the batch is from the same task, so there is no need to divide the input batch by task
        logits, _ = self.forward(
            x, stage="validation", task_ids=val_task_id
        )  # use the corresponding head to get the logits
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # return metrics for the lightning loggers callback to handle at `on_validation_batch_end()`
        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,
        }

    def test_step(
        self, batch: DataLoader, batch_idx: int, dataloader_idx: int = 0
    ) -> dict[str, Tensor]:
        r"""Test step. This is done task by task rather than mixing the tasks in batches.

        **Args:**
        - **batch** (`Any`): a batch of test data.
        - **dataloader_idx** (`int`): the task ID of the seen task to be tested. A default value of 0 is given, otherwise the LightningModule will raise a `RuntimeError`.

        **Returns:**
        - **outputs** (`dict[str, Tensor]`): a dictionary containing the loss and accuracy from this test step. Keys (`str`) are the metric names, and values (`Tensor`) are the metrics.
        """
        test_task_id = self.get_test_task_id_from_dataloader_idx(dataloader_idx)

        x, y, _ = batch

        # the batch is from the same task, so there is no need to divide the input batch by task
        logits, _ = self.forward(
            x, stage="test", task_ids=test_task_id
        )  # use the corresponding head to get the logits
        loss_cls = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        # return metrics for the lightning loggers callback to handle at `on_test_batch_end()`
        return {
            "preds": preds,
            "loss_cls": loss_cls,
            "acc": acc,
        }
class JointLearning(clarena.mtl_algorithms.base.MTLAlgorithm):

Joint learning algorithm.

The most naive approach to multi-task learning: it directly trains on all tasks jointly.
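
For orientation, here is a minimal sketch of how a `JointLearning` module might be wired into a standard Lightning training loop. The `backbone`, `heads`, and `datamodule` objects are placeholders: their construction depends on the rest of the clarena experiment configuration and is not shown in this module.

import lightning as L

from clarena.mtl_algorithms import JointLearning

# `backbone` (a Backbone), `heads` (a HeadsMTL), and `datamodule` (a multi-task
# LightningDataModule whose training batches yield (x, y, task_ids)) are assumed
# to have been built elsewhere from the experiment config.
model = JointLearning(backbone=backbone, heads=heads)

trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)   # mixed-task batches -> training_step
trainer.test(model, datamodule=datamodule)  # one dataloader per task -> test_step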

JointLearning( backbone: clarena.backbones.Backbone, heads: clarena.heads.HeadsMTL, non_algorithmic_hparams: dict[str, typing.Any] = {})

Initialize the JointLearning algorithm with the network. It has no additional hyperparameters.

Args:

  • backbone (Backbone): backbone network.
  • heads (HeadsMTL): output heads.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters, i.e. hyperparameters not related to the algorithm itself, passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the save_hyperparameters() method. This is useful for experiment configuration and reproducibility.
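
To make the last argument concrete, here is an illustrative (not prescriptive) way it could be filled; the key names below are hypothetical and in practice come from the experiment config rather than from this class. Because the values go through save_hyperparameters(), they should afterwards be visible on model.hparams and picked up by Lightning loggers.

model = JointLearning(
    backbone=backbone,
    heads=heads,
    non_algorithmic_hparams={
        "optimizer": {"type": "SGD", "lr": 0.01, "momentum": 0.9},  # hypothetical keys
        "lr_scheduler": {"type": "StepLR", "step_size": 30},
    },
)

# Saved via `save_hyperparameters()` in the base class, hence accessible here:
print(model.hparams)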
def training_step(self, batch: Any) -> dict[str, torch.Tensor]:

Training step.

Args:

  • batch (Any): a batch of training data, which can come from any mix of tasks. Must include task IDs in the batch.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this training step. Keys (str) are the metric names, and values (Tensor) are the metrics. Must include the key 'loss', which is the total loss in the case of automatic optimization, according to the PyTorch Lightning docs.
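
Since the essential logic of this step is the per-task split and per-task averaging, the following self-contained toy reproduction of it on plain tensors (with cross-entropy standing in for self.criterion) may help when checking the bookkeeping outside the LightningModule:

import torch
import torch.nn.functional as F

# Toy mixed-task batch: 6 samples, 4 classes, task IDs 0 and 1.
logits = torch.randn(6, 4)
y = torch.tensor([0, 1, 2, 3, 1, 0])
task_ids = torch.tensor([0, 0, 1, 1, 1, 0])

preds = torch.zeros_like(y)
loss_cls, acc = 0.0, 0.0

unique_tasks = torch.unique(task_ids)
for task_id in unique_tasks:
    idx = (task_ids == task_id).nonzero(as_tuple=True)[0]  # samples of this task
    logits_t, y_t = logits[idx], y[idx]

    loss_cls = loss_cls + F.cross_entropy(logits_t, y_t)  # per-task loss
    preds_t = logits_t.argmax(dim=1)
    preds[idx] = preds_t
    acc = acc + (preds_t == y_t).float().mean()  # per-task accuracy

loss_cls = loss_cls / len(unique_tasks)  # equal weight per task, not per sample
acc = acc / len(unique_tasks)
print(loss_cls.item(), acc.item())

Note that the averaging is over tasks rather than over samples, so a task with only a few samples in the batch contributes as much to the loss as a heavily represented one.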
def validation_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:

Validation step. This is done task by task rather than mixing the tasks in batches.

Args:

  • batch (Any): a batch of validation data.
  • dataloader_idx (int): the task ID of the seen task to be validated. A default value of 0 is given, otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this validation step. Keys (str) are the metric names, and values (Tensor) are the metrics.
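
Validation runs one dataloader per task, and Lightning passes that dataloader's position as dataloader_idx; get_val_task_id_from_dataloader_idx(), presumably provided by the MTL base class, translates this position back into a task ID. As a hedged sketch of the datamodule side, with a hypothetical val_datasets mapping from task IDs to datasets, the per-task dataloaders could be produced like this (the same indexing applies to test_step):

from torch.utils.data import DataLoader

def val_dataloader(self) -> list[DataLoader]:
    # One validation dataloader per task; the position in this list is the
    # `dataloader_idx` that Lightning passes to `validation_step`.
    return [
        DataLoader(self.val_datasets[task_id], batch_size=64)
        for task_id in sorted(self.val_datasets)
    ]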
def test_step( self, batch: torch.utils.data.dataloader.DataLoader, batch_idx: int, dataloader_idx: int = 0) -> dict[str, torch.Tensor]:

Test step. This is done task by task rather than mixing the tasks in batches.

Args:

  • batch (Any): a batch of test data.
  • dataloader_idx (int): the task ID of the seen task to be tested. A default value of 0 is given, otherwise the LightningModule will raise a RuntimeError.

Returns:

  • outputs (dict[str, Tensor]): a dictionary containing the loss and accuracy from this test step. Keys (str) are the metric names, and values (Tensor) are the metrics.