clarena.mtl_algorithms

Multi-Task Learning Algorithms

This submodule provides the multi-task learning algorithms in CLArena.

Here are the base classes for MTL algorithms, which inherit from PyTorch Lightning LightningModule:

  • MTLAlgorithm: the base class for all multi-task learning algorithms.

Please note that this is the API documentation. Please refer to the main documentation pages for more information about how to configure and implement MTL algorithms:

  • Configure MTL Algorithm: https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/mtl-algorithm
  • Implement Custom MTL Algorithm: https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/mtl-algorithm

r"""

# Multi-Task Learning Algorithms

This submodule provides the **multi-task learning algorithms** in CLArena.

Here are the base classes for MTL algorithms, which inherit from PyTorch Lightning `LightningModule`:

- `MTLAlgorithm`: the base class for all multi-task learning algorithms.


Please note that this is the API documentation. Please refer to the main documentation pages for more information about how to configure and implement MTL algorithms:

- [**Configure MTL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/components/mtl-algorithm)
- [**Implement Custom MTL Algorithm**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/custom-implementation/mtl-algorithm)

"""

from .base import MTLAlgorithm
from .joint_learning import JointLearning

__all__ = ["MTLAlgorithm", "joint_learning"]
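For orientation, here is a minimal, hypothetical sketch of a custom MTL algorithm built on `MTLAlgorithm`; the class name `SketchMTL` and the `reg_weight` argument are illustrative assumptions, and the Implement Custom MTL Algorithm page linked above is the authoritative guide.

# A hypothetical subclass of MTLAlgorithm. Only the constructor is sketched;
# training/validation/test steps would be added as standard Lightning hooks.
from typing import Any

from clarena.backbones import Backbone
from clarena.heads import HeadsMTL
from clarena.mtl_algorithms import MTLAlgorithm


class SketchMTL(MTLAlgorithm):
    r"""A hypothetical MTL algorithm with one extra algorithm-specific hyperparameter."""

    def __init__(
        self,
        backbone: Backbone,
        heads: HeadsMTL,
        reg_weight: float = 0.1,  # illustrative algorithm-specific hyperparameter
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
        )
        self.reg_weight = reg_weight
        r"""Weight of the (hypothetical) extra regularization term."""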
class MTLAlgorithm(lightning.pytorch.core.module.LightningModule):
class MTLAlgorithm(LightningModule):
    r"""The base class of multi-task learning algorithms."""

    def __init__(
        self,
        backbone: Backbone,
        heads: HeadsMTL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""
        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **heads** (`HeadsMTL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself but are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__()
        self.save_hyperparameters(non_algorithmic_hparams)

        # components
        self.backbone: Backbone = backbone
        r"""The backbone network."""
        self.heads: HeadsMTL = heads
        r"""The output heads."""
        self.optimizer: Optimizer
        r"""Optimizer (partially initialized) for backpropagation. Will be equipped with parameters in `configure_optimizers()`."""
        self.lr_scheduler: LRScheduler | None
        r"""The learning rate scheduler for the optimizer. If `None`, no scheduler is used."""
        self.criterion = nn.CrossEntropyLoss()
        r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.if_forward_func_return_logits_only: bool = False
        r"""Whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and other information (such as hidden activations). Default is `False`."""

        MTLAlgorithm.sanity_check(self)

    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check backbone and heads compatibility
        if self.backbone.output_dim != self.heads.input_dim:
            raise ValueError(
                "The output_dim of the backbone network must equal the input_dim of the MTL heads!"
            )

    def setup_tasks(
        self,
        task_ids: list[int],
        num_classes: dict[int, int],
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ) -> None:
        r"""Set up tasks for the MTL algorithm. This must be done before the `forward()` method is called.

        **Args:**
        - **task_ids** (`list[int]`): the list of task IDs.
        - **num_classes** (`dict[int, int]`): a dictionary mapping each task ID to its number of classes.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialized).
        - **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
        """
        self.heads.setup_tasks(task_ids=task_ids, num_classes=num_classes)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def get_val_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        r"""Get the validation task ID from the dataloader index.

        **Args:**
        - **dataloader_idx** (`int`): the dataloader index.

        **Returns:**
        - **val_task_id** (`int`): the validation task ID.
        """
        dataset_val = self.trainer.datamodule.dataset_val
        val_task_id = list(dataset_val.keys())[dataloader_idx]
        return val_task_id

    def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        r"""Get the test task ID from the dataloader index.

        **Args:**
        - **dataloader_idx** (`int`): the dataloader index.

        **Returns:**
        - **test_task_id** (`int`): the test task ID.
        """
        dataset_test = self.trainer.datamodule.dataset_test
        test_task_id = list(dataset_test.keys())[dataloader_idx]
        return test_task_id

    def set_forward_func_return_logits_only(
        self, forward_func_return_logits_only: bool
    ) -> None:
        r"""Set whether the `forward()` method returns logits only. This is useful for some CL algorithms that require the forward function to return logits only, such as FG-AdaHAT.

        **Args:**
        - **forward_func_return_logits_only** (`bool`): whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and other information (such as hidden activations).
        """
        self.if_forward_func_return_logits_only = forward_func_return_logits_only

    def forward(self, input: Tensor, task_ids: int | Tensor, stage: str) -> Tensor:
        r"""The forward pass for data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many MTL algorithms, including the vanilla JointLearning algorithm, use.

        This forward pass does not accept an input batch that mixes different tasks. Please make sure the input batch comes from a single task. If you want to use this forward pass for different tasks, please divide the input batch by task and call this forward pass for each task separately.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **task_ids** (`int` | `Tensor`): the task ID(s) for the input data. If the whole input batch is from the same task, this can be a single integer.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
        """
        feature, activations = self.backbone(input, stage=stage)
        logits = self.heads(feature, task_ids)
        return (
            logits if self.if_forward_func_return_logits_only else (logits, activations)
        )

    def configure_optimizers(self) -> Optimizer:
        r"""Configure the optimizer (and optional learning rate scheduler), a hook required by Lightning. See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
        # finish the partially initialized optimizer by supplying the model parameters. The `parameters()` method of this `MTLAlgorithm` (inherited from `LightningModule`) returns both backbone and heads parameters
        fully_initialized_optimizer = self.optimizer(params=self.parameters())

        if self.lr_scheduler:
            fully_initialized_lr_scheduler = self.lr_scheduler(
                optimizer=fully_initialized_optimizer
            )

            return {
                "optimizer": fully_initialized_optimizer,
                "lr_scheduler": {
                    "scheduler": fully_initialized_lr_scheduler,
                    "monitor": "learning_curve/val/loss_cls",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }

        return {"optimizer": fully_initialized_optimizer}

The base class of multi-task learning algorithms.

MTLAlgorithm( backbone: clarena.backbones.Backbone, heads: clarena.heads.HeadsMTL, non_algorithmic_hparams: dict[str, typing.Any] = {})
    def __init__(
        self,
        backbone: Backbone,
        heads: HeadsMTL,
        non_algorithmic_hparams: dict[str, Any] = {},
    ) -> None:
        r"""
        **Args:**
        - **backbone** (`Backbone`): backbone network.
        - **heads** (`HeadsMTL`): output heads.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself but are passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the `save_hyperparameters()` method, which is useful for experiment configuration and reproducibility.
        """
        super().__init__()
        self.save_hyperparameters(non_algorithmic_hparams)

        # components
        self.backbone: Backbone = backbone
        r"""The backbone network."""
        self.heads: HeadsMTL = heads
        r"""The output heads."""
        self.optimizer: Optimizer
        r"""Optimizer (partially initialized) for backpropagation. Will be equipped with parameters in `configure_optimizers()`."""
        self.lr_scheduler: LRScheduler | None
        r"""The learning rate scheduler for the optimizer. If `None`, no scheduler is used."""
        self.criterion = nn.CrossEntropyLoss()
        r"""The loss function between the output logits and the target labels. Default is cross-entropy loss."""

        self.if_forward_func_return_logits_only: bool = False
        r"""Whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and other information (such as hidden activations). Default is `False`."""

        MTLAlgorithm.sanity_check(self)

Args:

  • backbone (Backbone): backbone network.
  • heads (HeadsMTL): output heads.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself but are passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs via the save_hyperparameters() method, which is useful for experiment configuration and reproducibility.

The backbone network.

The output heads.

optimizer: torch.optim.optimizer.Optimizer

Optimizer (partially initialized) for backpropagation. Will be equipped with parameters in configure_optimizers().

lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None

The learning rate scheduler for the optimizer. If None, no scheduler is used.

criterion

The loss function between the output logits and the target labels. Default is cross-entropy loss.

if_forward_func_return_logits_only: bool

Whether the forward() method returns logits only. If False, it returns a tuple of the logits and other information (such as hidden activations). Default is False.
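As an illustration of how the constructor and save_hyperparameters() fit together, a hedged sketch follows; backbone and heads stand for already-built clarena.backbones.Backbone and clarena.heads.HeadsMTL instances (normally created from the Hydra config), and JointLearning is assumed here to keep the base-class constructor signature.

# Hypothetical construction of an MTL algorithm object.
from clarena.mtl_algorithms import JointLearning

algorithm = JointLearning(
    backbone=backbone,  # a clarena.backbones.Backbone instance
    heads=heads,        # a clarena.heads.HeadsMTL instance
    non_algorithmic_hparams={"optimizer": {"lr": 1e-3}, "batch_size": 128},
)

# save_hyperparameters() records the non-algorithmic hyperparameters, so they
# show up in self.hparams, in loggers, and in checkpoints.
print(algorithm.hparams)  # contains the optimizer and batch_size entries above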

def sanity_check(self) -> None:
    def sanity_check(self) -> None:
        r"""Sanity check."""

        # check backbone and heads compatibility
        if self.backbone.output_dim != self.heads.input_dim:
            raise ValueError(
                "The output_dim of the backbone network must equal the input_dim of the MTL heads!"
            )

Sanity check.

def setup_tasks( self, task_ids: list[int], num_classes: dict[int, int], optimizer: torch.optim.optimizer.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None) -> None:
    def setup_tasks(
        self,
        task_ids: list[int],
        num_classes: dict[int, int],
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ) -> None:
        r"""Set up tasks for the MTL algorithm. This must be done before the `forward()` method is called.

        **Args:**
        - **task_ids** (`list[int]`): the list of task IDs.
        - **num_classes** (`dict[int, int]`): a dictionary mapping each task ID to its number of classes.
        - **optimizer** (`Optimizer`): the optimizer object (partially initialized).
        - **lr_scheduler** (`LRScheduler` | `None`): the learning rate scheduler for the optimizer. If `None`, no scheduler is used.
        """
        self.heads.setup_tasks(task_ids=task_ids, num_classes=num_classes)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

Set up tasks for the MTL algorithm. This must be done before the forward() method is called.

Args:

  • task_ids (list[int]): the list of task IDs.
  • num_classes (dict[int, int]): a dictionary mapping each task ID to its number of classes.
  • optimizer (Optimizer): the optimizer object (partially initialized).
  • lr_scheduler (LRScheduler | None): the learning rate scheduler for the optimizer. If None, no scheduler is used.
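A hedged usage sketch of this hook; the task IDs, class counts, and the functools.partial pattern below are illustrative assumptions (in CLArena these values normally come from the datamodule and the Hydra config), and algorithm stands for an MTLAlgorithm instance.

# The optimizer and scheduler are passed partially initialized (no parameters
# yet); configure_optimizers() later binds them to the model parameters.
from functools import partial

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

algorithm.setup_tasks(
    task_ids=[1, 2, 3],
    num_classes={1: 10, 2: 5, 3: 2},                      # classes per task
    optimizer=partial(Adam, lr=1e-3),                     # completed in configure_optimizers()
    lr_scheduler=partial(ReduceLROnPlateau, factor=0.5),  # or None to disable scheduling
)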
def get_val_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
    def get_val_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        r"""Get the validation task ID from the dataloader index.

        **Args:**
        - **dataloader_idx** (`int`): the dataloader index.

        **Returns:**
        - **val_task_id** (`int`): the validation task ID.
        """
        dataset_val = self.trainer.datamodule.dataset_val
        val_task_id = list(dataset_val.keys())[dataloader_idx]
        return val_task_id

Get the validation task ID from the dataloader index.

Args:

  • dataloader_idx (int): the dataloader index.

Returns:

  • val_task_id (int): the validation task ID.
def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
    def get_test_task_id_from_dataloader_idx(self, dataloader_idx: int) -> int:
        r"""Get the test task ID from the dataloader index.

        **Args:**
        - **dataloader_idx** (`int`): the dataloader index.

        **Returns:**
        - **test_task_id** (`int`): the test task ID.
        """
        dataset_test = self.trainer.datamodule.dataset_test
        test_task_id = list(dataset_test.keys())[dataloader_idx]
        return test_task_id

Get the test task ID from the dataloader index.

Args:

  • dataloader_idx (int): the dataloader index.

Returns:

  • test_task_id (int): the test task ID.
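Both getters rely on the datamodule storing its validation/test datasets as dictionaries keyed by task ID, so the dataloader index simply indexes the key order. A small illustration, where the dictionary is a stand-in for self.trainer.datamodule.dataset_test:

# Stand-in for the datamodule's per-task test datasets, keyed by task ID.
dataset_test = {1: "dataset of task 1", 3: "dataset of task 3", 7: "dataset of task 7"}

# dataloader_idx follows the key order, so index 1 maps to task ID 3.
test_task_id = list(dataset_test.keys())[1]
assert test_task_id == 3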
def set_forward_func_return_logits_only(self, forward_func_return_logits_only: bool) -> None:
    def set_forward_func_return_logits_only(
        self, forward_func_return_logits_only: bool
    ) -> None:
        r"""Set whether the `forward()` method returns logits only. This is useful for some CL algorithms that require the forward function to return logits only, such as FG-AdaHAT.

        **Args:**
        - **forward_func_return_logits_only** (`bool`): whether the `forward()` method returns logits only. If `False`, it returns a tuple of the logits and other information (such as hidden activations).
        """
        self.if_forward_func_return_logits_only = forward_func_return_logits_only

Set whether the forward() method returns logits only. This is useful for some CL algorithms that require the forward function to return logits only, such as FG-AdaHAT.

Args:

  • forward_func_return_logits_only (bool): whether the forward() method returns logits only. If False, it returns a tuple of the logits and other information (such as hidden activations).
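A short sketch of the effect of this switch; algorithm, x, and task_id are placeholders for an MTLAlgorithm instance, an input batch from a single task, and its task ID.

# Default behaviour: forward() returns the logits together with the activations.
logits, activations = algorithm(x, task_id, stage="validation")

# After switching, forward() returns the logits tensor only.
algorithm.set_forward_func_return_logits_only(True)
logits = algorithm(x, task_id, stage="validation")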
def forward( self, input: torch.Tensor, task_ids: int | torch.Tensor, stage: str) -> torch.Tensor:
    def forward(self, input: Tensor, task_ids: int | Tensor, stage: str) -> Tensor:
        r"""The forward pass for data. Note that it has nothing to do with the `forward()` method in `nn.Module`. This definition provides a template that many MTL algorithms, including the vanilla JointLearning algorithm, use.

        This forward pass does not accept an input batch that mixes different tasks. Please make sure the input batch comes from a single task. If you want to use this forward pass for different tasks, please divide the input batch by task and call this forward pass for each task separately.

        **Args:**
        - **input** (`Tensor`): the input tensor from data.
        - **task_ids** (`int` | `Tensor`): the task ID(s) for the input data. If the whole input batch is from the same task, this can be a single integer.
        - **stage** (`str`): the stage of the forward pass; one of:
            1. 'train': training stage.
            2. 'validation': validation stage.
            3. 'test': testing stage.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        - **activations** (`dict[str, Tensor]`): the hidden features (after activation) in each weighted layer. Key (`str`) is the weighted layer name, value (`Tensor`) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
        """
        feature, activations = self.backbone(input, stage=stage)
        logits = self.heads(feature, task_ids)
        return (
            logits if self.if_forward_func_return_logits_only else (logits, activations)
        )

The forward pass for data. Note that it has nothing to do with the forward() method in nn.Module. This definition provides a template that many MTL algorithms, including the vanilla JointLearning algorithm, use.

This forward pass does not accept an input batch that mixes different tasks. Please make sure the input batch comes from a single task. If you want to use this forward pass for different tasks, please divide the input batch by task and call this forward pass for each task separately.

Args:

  • input (Tensor): The input tensor from data.
  • task_ids (int | Tensor): the task ID(s) for the input data. If the whole input batch is from the same task, this can be a single integer.
  • stage (str): the stage of the forward pass; one of:
    1. 'train': training stage.
    2. 'validation': validation stage.
    3. 'test': testing stage.

Returns:

  • logits (Tensor): the output logits tensor.
  • activations (dict[str, Tensor]): the hidden features (after activation) in each weighted layer. Key (str) is the weighted layer name, value (Tensor) is the hidden feature tensor. This is used by continual learning algorithms that need the hidden features for various purposes.
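Since a single call must contain samples from one task only, a mixed batch is typically split by task first. A hedged sketch; the algorithm, batch_input, and batch_task_ids names are placeholders, not part of the API:

# Split a mixed batch by task and run one forward pass per task.
import torch

outputs: dict[int, torch.Tensor] = {}
for task_id in torch.unique(batch_task_ids):  # batch_task_ids: 1-D tensor of task IDs
    mask = batch_task_ids == task_id
    logits, activations = algorithm(batch_input[mask], int(task_id), stage="train")
    outputs[int(task_id)] = logits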
def configure_optimizers(self) -> torch.optim.optimizer.Optimizer:
    def configure_optimizers(self) -> Optimizer:
        r"""Configure the optimizer (and optional learning rate scheduler), a hook required by Lightning. See [Lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers) for more details."""
        # finish the partially initialized optimizer by supplying the model parameters. The `parameters()` method of this `MTLAlgorithm` (inherited from `LightningModule`) returns both backbone and heads parameters
        fully_initialized_optimizer = self.optimizer(params=self.parameters())

        if self.lr_scheduler:
            fully_initialized_lr_scheduler = self.lr_scheduler(
                optimizer=fully_initialized_optimizer
            )

            return {
                "optimizer": fully_initialized_optimizer,
                "lr_scheduler": {
                    "scheduler": fully_initialized_lr_scheduler,
                    "monitor": "learning_curve/val/loss_cls",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }

        return {"optimizer": fully_initialized_optimizer}

Configure the optimizer (and optional learning rate scheduler), a hook required by Lightning. See the Lightning docs for more details.
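For reference, with the partially initialized optimizer and scheduler from the setup_tasks() sketch above, the hook effectively expands to the following Lightning optimizer configuration (a plain restatement of the source, not additional behaviour); algorithm is a placeholder for an MTLAlgorithm instance.

# What configure_optimizers() returns in that case.
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

optimizer = Adam(algorithm.parameters(), lr=1e-3)     # partial completed with the model parameters
scheduler = ReduceLROnPlateau(optimizer, factor=0.5)  # partial completed with the optimizer
lightning_config = {
    "optimizer": optimizer,
    "lr_scheduler": {
        "scheduler": scheduler,
        "monitor": "learning_curve/val/loss_cls",  # metric the scheduler watches
        "interval": "epoch",
        "frequency": 1,
    },
}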