clarena.heads

Output Heads

This submodule provides the output heads in CLArena.

There are three types of continual learning / unlearning heads in CLArena: HeadsTIL, HeadsCIL and HeadDIL, corresponding respectively to the three CL paradigms: Task-Incremental Learning (TIL), Class-Incremental Learning (CIL) and Domain-Incremental Learning (DIL). For Multi-Task Learning (MTL), we have HeadsMTL, which is a collection of independent heads for each task.

Please note that this is an API documentation. Please refer to the main documentation pages for more information about the heads.

r"""

# Output Heads

This submodule provides the **output heads** in CLArena.

There are three types of continual learning / unlearning heads in CLArena: `HeadsTIL`, `HeadsCIL` and `HeadDIL`, corresponding respectively to the three CL paradigms: Task-Incremental Learning (TIL), Class-Incremental Learning (CIL) and Domain-Incremental Learning (DIL). For Multi-Task Learning (MTL), we have `HeadsMTL`, which is a collection of independent heads for each task.

Please note that this is an API documentation. Please refer to the main documentation pages for more information about the heads:

- [**Configure CL Paradigm in Experiment Index Config**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/experiment-index-config)
- [**A Beginners' Guide to Continual Learning (Multi-head Classifier)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-classification)

"""

from .heads_cil import HeadsCIL
from .head_dil import HeadDIL
from .heads_til import HeadsTIL

from .heads_mtl import HeadsMTL

from .head_stl import HeadSTL

__all__ = ["HeadsTIL", "HeadsCIL", "HeadDIL", "HeadsMTL", "HeadSTL"]
class HeadsTIL(torch.nn.modules.module.Module):
class HeadsTIL(nn.Module):
    r"""The output heads for Task-Incremental Learning (TIL). An independent head is assigned to each TIL task; it takes the output of the backbone network and maps it to logits for predicting the classes of that task."""

    def __init__(self, input_dim: int) -> None:
        r"""Initializes a TIL heads object with no heads.

        **Args:**
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.heads: nn.ModuleDict = nn.ModuleDict()  # initially no heads
        r"""TIL output heads are stored independently in a `ModuleDict`. Keys are task IDs and values are the corresponding `nn.Linear` heads. We use `ModuleDict` rather than `dict` so that `LightningModule` can track these parameters, e.g. for automatic device placement and inclusion in model summaries.

        Note that the task IDs must be strings so that `LightningModule` can identify this part of the model."""
        self.head_t: nn.Linear | None = None
        r"""The output head for the current task. It is created when the task arrives and stored in `self.heads`."""

        self.input_dim: int = input_dim
        r"""Store the input dimension of the heads. Used when creating new heads."""

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Starts from 1."""

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        r"""Create the output head for task `task_id` when it arrives, if it does not already exist. This must be done before `forward()` is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        - **num_classes_t** (`int`): the number of classes in the task.
        """
        self.task_id = task_id
        if f"{self.task_id}" not in self.heads:  # keys are strings
            self.head_t = nn.Linear(self.input_dim, num_classes_t)
            self.heads[f"{self.task_id}"] = self.head_t

    def get_head(self, task_id: int) -> nn.Linear:
        r"""Get the output head for task `task_id`.

        **Args:**
        - **task_id** (`int`): the target task ID.

        **Returns:**
        - **head_t** (`nn.Linear`): the output head for task `task_id`.
        """
        return self.heads[f"{task_id}"]

    def forward(self, feature: Tensor, task_id: int) -> Tensor:
        r"""The forward pass for data from task `task_id`. The head is selected according to `task_id` and the feature is passed through it.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.
        - **task_id** (`int`): the task ID the data are from, provided by the task-incremental setting.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """

        head_t = self.get_head(task_id)
        logits = head_t(feature)

        return logits

The output heads for Task-Incremental Learning (TIL). An independent head is assigned to each TIL task; it takes the output of the backbone network and maps it to logits for predicting the classes of that task.

HeadsTIL(input_dim: int)

Initializes TIL heads object with no heads.

Args:

  • input_dim (int): the input dimension of the heads. Must be equal to the output_dim of the connected backbone.
heads: torch.nn.modules.container.ModuleDict

TIL output heads are stored independently in a ModuleDict. Keys are task IDs and values are the corresponding nn.Linear heads. We use ModuleDict rather than dict so that LightningModule can track these parameters, e.g. for automatic device placement and inclusion in model summaries.

Note that the task IDs must be strings so that LightningModule can identify this part of the model.

head_t: torch.nn.modules.linear.Linear | None

The output head for the current task. It is created when the task arrives and stored in self.heads.

input_dim: int

Store the input dimension of the heads. Used when creating new heads.

task_id: int

Task ID counter indicating which task is being processed. Self-updated during the task loop. Starts from 1.

def setup_task_id(self, task_id: int, num_classes_t: int) -> None:

Create the output head for task task_id when it arrives, if it does not already exist. This must be done before forward() is called.

Args:

  • task_id (int): the target task ID.
  • num_classes_t (int): the number of classes in the task.
def get_head(self, task_id: int) -> torch.nn.modules.linear.Linear:

Get the output head for task task_id.

Args:

  • task_id (int): the target task ID.

Returns:

  • head_t (nn.Linear): the output head for task task_id.
def forward(self, feature: torch.Tensor, task_id: int) -> torch.Tensor:

The forward pass for data from task task_id. The head is selected according to task_id and the feature is passed through it.

Args:

  • feature (Tensor): the feature tensor from the backbone network.
  • task_id (int): the task ID the data are from, provided by the task-incremental setting.

Returns:

  • logits (Tensor): the output logits tensor.
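To make the per-task dispatch concrete, here is a minimal, self-contained sketch that mirrors the `HeadsTIL` logic shown in the source above. It is a trimmed reimplementation using only `torch` (it does not import the `clarena` package, and omits the docstring attributes):

```python
import torch
from torch import nn


class HeadsTIL(nn.Module):
    """Trimmed sketch of the TIL heads logic shown above."""

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.heads = nn.ModuleDict()  # keys are string task IDs, values are nn.Linear heads
        self.input_dim = input_dim

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        # create the head for this task only once
        if f"{task_id}" not in self.heads:
            self.heads[f"{task_id}"] = nn.Linear(self.input_dim, num_classes_t)

    def forward(self, feature: torch.Tensor, task_id: int) -> torch.Tensor:
        # select the head for the given task and produce its logits
        return self.heads[f"{task_id}"](feature)


heads = HeadsTIL(input_dim=64)
heads.setup_task_id(1, num_classes_t=10)  # task 1 has 10 classes
heads.setup_task_id(2, num_classes_t=5)   # task 2 has 5 classes

feature = torch.randn(8, 64)  # a batch of backbone features
logits_t1 = heads(feature, task_id=1)  # shape (8, 10)
logits_t2 = heads(feature, task_id=2)  # shape (8, 5)
```

The TIL setting assumes the task ID is known at test time, which is why `forward()` can require it.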
class HeadsCIL(torch.nn.modules.module.Module):
class HeadsCIL(nn.Module):
    r"""The output heads for Class-Incremental Learning (CIL). The heads jointly cover all classes from the CIL tasks seen so far: the output of the backbone network is mapped to concatenated logits for predicting classes of all tasks."""

    def __init__(self, input_dim: int) -> None:
        """Initializes a CIL heads object with no heads.

        **Args:**
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.heads: nn.ModuleDict = nn.ModuleDict()  # initially no heads
        """CIL output heads are stored in a `ModuleDict`. Keys are task IDs and values are the corresponding `nn.Linear` heads. We use `ModuleDict` rather than `dict` so that `LightningModule` can track these parameters, e.g. for automatic device placement and inclusion in model summaries."""

        self.input_dim: int = input_dim
        """The input dimension of the heads. Used when creating new heads."""

        self.task_id: int
        """Task ID counter indicating which task is being processed. Self-updated during the task loop. Starts from 1."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        """Create the output head for task `task_id` when it arrives, if it does not already exist. This must be done before `forward()` is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        - **num_classes_t** (`int`): the number of classes in the task.
        """
        self.task_id = task_id
        if task_id not in self.processed_task_ids:  # record each task only once
            self.processed_task_ids.append(task_id)
        if f"{self.task_id}" not in self.heads:  # keys are strings
            self.heads[f"{self.task_id}"] = nn.Linear(self.input_dim, num_classes_t)

    def forward(self, feature: Tensor, task_id: int | None = None) -> Tensor:
        r"""The forward pass for data. The task the data are from is not provided. The feature is passed through all heads and the logits are concatenated across the classes of all processed tasks.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.
        - **task_id** (`int` or `None`): the task ID the data are from. In CIL, it is just a placeholder for API consistency with the TIL heads and never used. Best practice is not to provide this argument and leave it at the default value.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """
        logits = torch.cat(
            [self.heads[f"{t}"](feature) for t in self.processed_task_ids], dim=-1
        )  # concatenate logits of classes from all heads

        return logits

The output heads for Class-Incremental Learning (CIL). The heads jointly cover all classes from the CIL tasks seen so far: the output of the backbone network is mapped to concatenated logits for predicting classes of all tasks.

HeadsCIL(input_dim: int)

Initializes a CIL heads object with no heads.

Args:

  • input_dim (int): the input dimension of the heads. Must be equal to the output_dim of the connected backbone.
heads: torch.nn.modules.container.ModuleDict

CIL output heads are stored in a ModuleDict. Keys are task IDs and values are the corresponding nn.Linear heads. We use ModuleDict rather than dict so that LightningModule can track these parameters, e.g. for automatic device placement and inclusion in model summaries.

input_dim: int

The input dimension of the heads. Used when creating new heads.

task_id: int

Task ID counter indicating which task is being processed. Self-updated during the task loop. Starts from 1.

processed_task_ids: list[int]

Task IDs that have been processed.

def setup_task_id(self, task_id: int, num_classes_t: int) -> None:

Create the output head for task task_id when it arrives, if it does not already exist. This must be done before forward() is called.

Args:

  • task_id (int): the target task ID.
  • num_classes_t (int): the number of classes in the task.
def forward(self, feature: torch.Tensor, task_id: int | None = None) -> torch.Tensor:

The forward pass for data. The task the data are from is not provided. The feature is passed through all heads and the logits are concatenated across the classes of all processed tasks.

Args:

  • feature (Tensor): the feature tensor from the backbone network.
  • task_id (int or None): the task ID the data are from. In CIL, it is just a placeholder for API consistency with the TIL heads and never used. Best practice is not to provide this argument and leave it at the default value.

Returns:

  • logits (Tensor): the output logits tensor.
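The concatenation in `HeadsCIL.forward()` means the logits dimension grows as tasks arrive. A minimal, self-contained sketch mirroring the source above (again `torch`-only, not an import of `clarena`):

```python
import torch
from torch import nn


class HeadsCIL(nn.Module):
    """Trimmed sketch of the CIL heads logic shown above."""

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.heads = nn.ModuleDict()  # keys are string task IDs
        self.input_dim = input_dim
        self.processed_task_ids: list[int] = []

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        if task_id not in self.processed_task_ids:
            self.processed_task_ids.append(task_id)
        if f"{task_id}" not in self.heads:
            self.heads[f"{task_id}"] = nn.Linear(self.input_dim, num_classes_t)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        # concatenate logits of classes from all heads seen so far
        return torch.cat(
            [self.heads[f"{t}"](feature) for t in self.processed_task_ids], dim=-1
        )


heads = HeadsCIL(input_dim=64)
feature = torch.randn(8, 64)

heads.setup_task_id(1, num_classes_t=10)
logits_after_t1 = heads(feature)  # shape (8, 10)

heads.setup_task_id(2, num_classes_t=5)
logits_after_t2 = heads(feature)  # shape (8, 15): classes of both tasks
```

Unlike TIL, no task ID is needed at prediction time; the model must discriminate among all classes seen so far, which is what makes CIL the harder setting.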
class HeadDIL(torch.nn.modules.module.Module):
class HeadDIL(nn.Module):
    r"""The output head for Domain-Incremental Learning (DIL)."""

    def __init__(self, input_dim: int) -> None:
        r"""Initializes a DIL head object.

        **Args:**
        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.head: nn.Linear | None = None
        r"""DIL output head."""

        self.input_dim: int = input_dim
        r"""Store the input dimension of the head. Used when creating the head."""

        self._if_head_setup: bool = False
        r"""Flag indicating whether the head has been set up."""

    def if_head_setup(self) -> bool:
        r"""Check whether the head has been set up.

        **Returns:**
        - **if_head_setup** (`bool`): whether the head has been set up.
        """
        return self._if_head_setup

    def get_head(self, task_id: int | None = None) -> nn.Linear:
        r"""Get the output head for DIL.

        **Args:**
        - **task_id** (`int` or `None`): the task ID the data are from. This does not matter for the DIL head, as there is only one head for all tasks; it is just a placeholder for API consistency with the TIL heads and never used. Best practice is not to provide this argument and leave it at the default value.

        **Returns:**
        - **head** (`nn.Linear`): the output head for DIL.
        """
        return self.head

    def setup_task(self, num_classes: int) -> None:
        r"""Create the output head. This must be done before `forward()` is called.

        **Args:**
        - **num_classes** (`int`): the number of classes in the task.
        """
        self.head = nn.Linear(self.input_dim, num_classes)
        self._if_head_setup = True

    def forward(self, feature: Tensor, task_id: int | None = None) -> Tensor:
        r"""The forward pass for data. The task the data are from does not need to be known; the single shared head is used.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.
        - **task_id** (`int` or `None`): the task ID the data are from. In DIL, it is just a placeholder for API consistency with the TIL heads and never used. Best practice is not to provide this argument and leave it at the default value.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """
        logits = self.head(feature)

        return logits

The output head for Domain-Incremental Learning (DIL).

HeadDIL(input_dim: int)

Initializes DIL head object.

Args:

  • input_dim (int): the input dimension of the head. Must be equal to the output_dim of the connected backbone.
head: torch.nn.modules.linear.Linear | None

DIL output head.

input_dim: int

Store the input dimension of the head. Used when creating new head.

def if_head_setup(self) -> bool:

Check whether the head has been set up.

Returns:

  • if_head_setup (bool): whether the head has been set up.
def get_head(self, task_id: int | None = None) -> torch.nn.modules.linear.Linear:

Get the output head for DIL.

Args:

  • task_id (int or None): the task ID the data are from. This does not matter for the DIL head, as there is only one head for all tasks; it is just a placeholder for API consistency with the TIL heads and never used. Best practice is not to provide this argument and leave it at the default value.

Returns:

  • head (nn.Linear): the output head for DIL.
def setup_task(self, num_classes: int) -> None:

Create the output head. This must be done before forward() is called.

Args:

  • num_classes (int): the number of classes in the task.
def forward(self, feature: torch.Tensor, task_id: int | None = None) -> torch.Tensor:

The forward pass for data. The task the data are from does not need to be known; the single shared head is used.

Args:

  • feature (Tensor): the feature tensor from the backbone network.
  • task_id (int or None): the task ID the data are from. In DIL, it is just a placeholder for API consistency with the TIL heads and never used. Best practice is not to provide this argument and leave it at the default value.

Returns:

  • logits (Tensor): the output logits tensor.
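Since DIL keeps a single head shared across all domains, the behaviour of `setup_task()` and `forward()` reduces to one `nn.Linear` created once. A `torch`-only sketch of that invariant:

```python
import torch
from torch import nn

input_dim, num_classes = 64, 10

# one shared head for every domain, created once before any forward pass
# (this is what HeadDIL.setup_task() does internally)
head = nn.Linear(input_dim, num_classes)

# features from two different domains
feature_domain_a = torch.randn(8, input_dim)
feature_domain_b = torch.randn(8, input_dim)

# the same head produces logits regardless of which domain the data are from
logits_a = head(feature_domain_a)  # shape (8, 10)
logits_b = head(feature_domain_b)  # shape (8, 10)
```

The class set is the same across domains, so the logits dimension never changes; only the input distribution shifts.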
class HeadsMTL(torch.nn.modules.module.Module):
class HeadsMTL(nn.Module):
    r"""The output heads for Multi-Task Learning (MTL). An independent head is assigned to each task; it takes the output of the backbone network and maps it to logits for predicting the classes of that task."""

    def __init__(self, input_dim: int) -> None:
        r"""Initializes an MTL heads object.

        **Args:**
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.heads: nn.ModuleDict = nn.ModuleDict()
        r"""MTL output heads are stored independently in a `ModuleDict`. Keys are task IDs and values are the corresponding `nn.Linear` heads. We use `ModuleDict` rather than `dict` so that `LightningModule` can track these parameters, e.g. for automatic device placement and inclusion in model summaries.

        Note that the task IDs must be strings so that `LightningModule` can identify this part of the model."""

        self.input_dim: int = input_dim
        r"""Store the input dimension of the heads. Used when creating new heads."""

    def setup_tasks(self, task_ids: list[int], num_classes: dict[int, int]) -> None:
        r"""Create the output heads. This must be done before `forward()` is called.

        **Args:**
        - **task_ids** (`list[int]`): the target task IDs.
        - **num_classes** (`dict[int, int]`): the number of classes in each task. Keys are task IDs and values are the number of classes for the corresponding task.
        """
        for task_id in task_ids:
            self.heads[f"{task_id}"] = nn.Linear(self.input_dim, num_classes[task_id])

    def get_head(self, task_id: int) -> nn.Linear:
        r"""Get the output head for task `task_id`.

        **Args:**
        - **task_id** (`int`): the target task ID.

        **Returns:**
        - **head_t** (`nn.Linear`): the output head for task `task_id`.
        """
        return self.heads[f"{task_id}"]

    def forward(self, feature: Tensor, task_ids: int | Tensor) -> Tensor:
        r"""The forward pass. A head is selected for each sample according to its task ID and the features are passed through the corresponding heads.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.
        - **task_ids** (`int` | `Tensor`): the task ID(s) for the input data. If the whole batch is from the same task, this can be a single integer.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """

        if isinstance(task_ids, int):
            head_t = self.get_head(task_ids)
            logits = head_t(feature)

        elif isinstance(task_ids, Tensor):
            logits_list = []
            for task_id in torch.unique(task_ids):  # for each unique task in the batch
                idx = (task_ids == task_id).nonzero(as_tuple=True)[
                    0
                ]  # indices of the current task in the batch
                features_t = feature[idx]  # get the features for the current task
                head_t = self.get_head(task_id.item())
                logits_t = head_t(features_t)
                logits_list.append((idx, logits_t))

            # reconstruct the logits tensor in the order of task_ids
            # (note: this assumes all heads output the same number of classes)
            logits = torch.zeros(len(task_ids), logits_t.size(1), device=feature.device)
            for idx, logits_t in logits_list:
                logits[idx] = logits_t

        return logits

The output heads for Multi-Task Learning (MTL). An independent head is assigned to each task; it takes the output of the backbone network and maps it to logits for predicting the classes of that task.

HeadsMTL(input_dim: int)

Initializes MTL heads object.

Args:

  • input_dim (int): the input dimension of the heads. Must be equal to the output_dim of the connected backbone.
heads: torch.nn.modules.container.ModuleDict

MTL output heads are stored independently in a ModuleDict. Keys are task IDs and values are the corresponding nn.Linear heads. We use ModuleDict rather than dict so that LightningModule can track these parameters, e.g. for automatic device placement and inclusion in model summaries.

Note that the task IDs must be strings so that LightningModule can identify this part of the model.

input_dim: int

Store the input dimension of the heads. Used when creating new heads.

def setup_tasks(self, task_ids: list[int], num_classes: dict[int, int]) -> None:

Create the output heads. This must be done before forward() is called.

Args:

  • task_ids (list[int]): the target task IDs.
  • num_classes (dict[int, int]): the number of classes in each task. Keys are task IDs and values are the number of classes for the corresponding task.
def get_head(self, task_id: int) -> torch.nn.modules.linear.Linear:

Get the output head for task task_id.

Args:

  • task_id (int): the target task ID.

Returns:

  • head_t (nn.Linear): the output head for task task_id.
def forward( self, feature: torch.Tensor, task_ids: int | torch.Tensor) -> torch.Tensor:
55    def forward(self, feature: Tensor, task_ids: int | Tensor) -> Tensor:
56        r"""The forward pass for data from task `task_id`. A head is selected according to the task_id and the feature is passed through the head.
57
58        **Args:**
59        - **feature** (`Tensor`): the feature tensor from the backbone network.
60        - **task_ids** (`int` | `Tensor`): the task ID(s) for the input data. If the input batch is from the same task, this can be a single integer.
61
62        **Returns:**
63        - **logits** (`Tensor`): the output logits tensor.
64        """
65
66        if isinstance(task_ids, int):
67            head_t = self.get_head(task_ids)
68            logits = head_t(feature)
69
70        elif isinstance(task_ids, Tensor):
71            logits_list = []
72            for task_id in torch.unique(task_ids):  # for each unique task in the batch
73                idx = (task_ids == task_id).nonzero(as_tuple=True)[
74                    0
75                ]  # indices of the current task in the batch
76                features_t = feature[idx]  # get the features for the current task
77                head_t = self.get_head(task_id.item())
78                logits_t = head_t(features_t)
79                logits_list.append((idx, logits_t))
80
81            # reconstruct logits tensor in the order of task_ids
 82            logits = torch.zeros(len(task_ids), logits_list[0][1].size(1), device=feature.device)  # assumes all heads in the batch output the same number of classes
83            for idx, logits_t in logits_list:
84                logits[idx] = logits_t
85
86        return logits

The forward pass for data from one or more tasks. The head for each task is selected according to task_ids and the corresponding features are passed through it.

Args:

  • feature (Tensor): the feature tensor from the backbone network.
  • task_ids (int | Tensor): the task ID(s) for the input data. If the input batch is from the same task, this can be a single integer.

Returns:

  • logits (Tensor): the output logits tensor.
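The mixed-task branch of `forward()` can be exercised with a standalone sketch. The dimensions and task IDs below are hypothetical; the grouping, per-head forward, and scatter back into batch order follow the source above, under its assumption that all heads in the batch share one class count.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_dim, n_classes = 8, 3  # hypothetical dimensions

# Two independent heads, keyed by task ID as in the source above.
heads = nn.ModuleDict({str(t): nn.Linear(input_dim, n_classes) for t in (1, 2)})

feature = torch.randn(5, input_dim)          # batch of backbone features
task_ids = torch.tensor([1, 2, 1, 2, 2])     # per-sample task labels

logits = torch.zeros(len(task_ids), n_classes)
for task_id in torch.unique(task_ids):
    # indices of this task's samples in the batch
    idx = (task_ids == task_id).nonzero(as_tuple=True)[0]
    # route this task's features through its own head, scatter back in order
    logits[idx] = heads[str(task_id.item())](feature[idx])

# row i of `logits` was produced by the head of task_ids[i]
assert torch.allclose(logits[0], heads["1"](feature[0:1]).squeeze(0))
```

Grouping by `torch.unique` means each head runs once per batch rather than once per sample, which keeps the routing cheap.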
class HeadSTL(torch.nn.modules.module.Module):
14class HeadSTL(nn.Module):
15    r"""The output head for Single-Task Learning (STL)."""
16
17    def __init__(self, input_dim: int) -> None:
18        r"""Initializes STL head object.
19
20        **Args:**
21        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
22        """
23        super().__init__()
24
25        self.head: nn.Linear | None = None
26        r"""STL output head. """
27
28        self.input_dim: int = input_dim
29        r"""Store the input dimension of the head. Used when creating new head."""
30
31    def setup_task(self, num_classes: int) -> None:
32        r"""Create the output head. This must be done before `forward()` is called.
33
34        **Args:**
35        - **num_classes** (`int`): the number of classes in the task.
36        """
37        self.head = nn.Linear(self.input_dim, num_classes)
38
39    def forward(self, feature: Tensor) -> Tensor:
40        r"""The forward pass for data.
41
42        **Args:**
43        - **feature** (`Tensor`): the feature tensor from the backbone network.
44
45        **Returns:**
46        - **logits** (`Tensor`): the output logits tensor.
47        """
48        logits = self.head(feature)
49
50        return logits

The output head for Single-Task Learning (STL).

HeadSTL(input_dim: int)
17    def __init__(self, input_dim: int) -> None:
18        r"""Initializes STL head object.
19
20        **Args:**
21        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
22        """
23        super().__init__()
24
25        self.head: nn.Linear | None = None
26        r"""STL output head. """
27
28        self.input_dim: int = input_dim
29        r"""Store the input dimension of the head. Used when creating new head."""

Initializes STL head object.

Args:

  • input_dim (int): the input dimension of the head. Must be equal to the output_dim of the connected backbone.
head: torch.nn.modules.linear.Linear

STL output head.

input_dim: int

Store the input dimension of the head. Used when creating new head.

def setup_task(self, num_classes: int) -> None:
31    def setup_task(self, num_classes: int) -> None:
32        r"""Create the output head. This must be done before `forward()` is called.
33
34        **Args:**
35        - **num_classes** (`int`): the number of classes in the task.
36        """
37        self.head = nn.Linear(self.input_dim, num_classes)

Create the output head. This must be done before forward() is called.

Args:

  • num_classes (int): the number of classes in the task.
def forward(self, feature: torch.Tensor) -> torch.Tensor:
39    def forward(self, feature: Tensor) -> Tensor:
40        r"""The forward pass for data.
41
42        **Args:**
43        - **feature** (`Tensor`): the feature tensor from the backbone network.
44
45        **Returns:**
46        - **logits** (`Tensor`): the output logits tensor.
47        """
48        logits = self.head(feature)
49
50        return logits

The forward pass for data.

Args:

  • feature (Tensor): the feature tensor from the backbone network.

Returns:

  • logits (Tensor): the output logits tensor.
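End to end, the STL head is used in two steps: create the head for the task's class count, then forward backbone features through it. The sketch below condenses the class from the source above; the dimensions are hypothetical.

```python
import torch
from torch import nn

class HeadSTL(nn.Module):
    """Condensed copy of the STL head shown above."""

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.head: nn.Linear | None = None

    def setup_task(self, num_classes: int) -> None:
        # must be called before forward()
        self.head = nn.Linear(self.input_dim, num_classes)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        return self.head(feature)

head = HeadSTL(input_dim=64)      # hypothetical backbone output dimension
head.setup_task(num_classes=10)   # hypothetical class count
feature = torch.randn(2, 64)      # stand-in for backbone features
print(head(feature).shape)  # torch.Size([2, 10])
```

Calling `forward()` before `setup_task()` fails because `self.head` is still `None`, which is why the docstring insists on the setup step first.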