clarena.heads
Output Heads
This submodule provides the output heads in CLArena.
There are three types of continual learning / unlearning heads in CLArena: HeadsTIL, HeadsCIL and HeadDIL, corresponding to the three CL paradigms: Task-Incremental Learning (TIL), Class-Incremental Learning (CIL) and Domain-Incremental Learning (DIL). For Multi-Task Learning (MTL), we have HeadsMTL, a collection of independent heads, one for each task.
Please note that this is API documentation. Please refer to the main documentation pages for more information about the heads.
```python
r"""

# Output Heads

This submodule provides the **output heads** in CLArena.

There are three types of continual learning / unlearning heads in CLArena: `HeadsTIL`, `HeadsCIL` and `HeadDIL`, corresponding to the three CL paradigms: Task-Incremental Learning (TIL), Class-Incremental Learning (CIL) and Domain-Incremental Learning (DIL). For Multi-Task Learning (MTL), we have `HeadsMTL`, a collection of independent heads, one for each task.

Please note that this is API documentation. Please refer to the main documentation pages for more information about the heads:

- [**Configure CL Paradigm in Experiment Index Config**](https://pengxiang-wang.com/projects/continual-learning-arena/docs/configure-your-experiment/experiment-index-config)
- [**A Beginners' Guide to Continual Learning (Multi-head Classifier)**](https://pengxiang-wang.com/posts/continual-learning-beginners-guide#sec-CL-classification)

"""

from .heads_cil import HeadsCIL
from .head_dil import HeadDIL
from .heads_til import HeadsTIL
from .heads_mtl import HeadsMTL
from .head_stl import HeadSTL

__all__ = ["HeadsTIL", "HeadsCIL", "HeadDIL", "HeadsMTL", "HeadSTL"]
```
```python
from torch import Tensor, nn


class HeadsTIL(nn.Module):
    r"""The output heads for Task-Incremental Learning (TIL). An independent head assigned to each TIL task takes the output from the backbone network and forwards it into logits for predicting the classes of that task."""

    def __init__(self, input_dim: int) -> None:
        r"""Initializes a TIL heads object with no heads.

        **Args:**
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.heads: nn.ModuleDict = nn.ModuleDict()  # initially no heads
        r"""TIL output heads are stored independently in a `ModuleDict`. Keys are task IDs and values are the corresponding `nn.Linear` heads. We use `ModuleDict` rather than `dict` so that `LightningModule` can track these model parameters, e.g. to move them to the device automatically and record them in model summaries.

        Note that the task IDs must be strings so that `LightningModule` can identify this part of the model."""
        self.head_t: nn.Linear | None = None
        r"""The output head for the current task. It is created when the task arrives and stored in `self.heads`."""

        self.input_dim: int = input_dim
        r"""Store the input dimension of the heads. Used when creating new heads."""

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Starts from 1."""

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        r"""Create the output head for task `task_id` when it arrives, if it does not exist yet. This must be done before `forward()` is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        - **num_classes_t** (`int`): the number of classes in the task.
        """
        self.task_id = task_id
        if f"{self.task_id}" not in self.heads:  # ModuleDict keys are strings
            self.head_t = nn.Linear(self.input_dim, num_classes_t)
            self.heads[f"{self.task_id}"] = self.head_t

    def get_head(self, task_id: int) -> nn.Linear:
        r"""Get the output head for task `task_id`.

        **Args:**
        - **task_id** (`int`): the target task ID.

        **Returns:**
        - **head_t** (`nn.Linear`): the output head for task `task_id`.
        """
        return self.heads[f"{task_id}"]

    def forward(self, feature: Tensor, task_id: int) -> Tensor:
        r"""The forward pass for data from task `task_id`. A head is selected according to `task_id` and the feature is passed through it.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.
        - **task_id** (`int`): the task ID the data are from, provided by the task-incremental setting.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """
        head_t = self.get_head(task_id)
        logits = head_t(feature)

        return logits
```
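As a quick illustration (not part of the library source), the sketch below shows how `HeadsTIL` is typically used: set up one head per task, then route each batch to its task's head. The class body is restated in condensed form (docstrings omitted), and the backbone dimension `64` and class counts are made-up values.

```python
import torch
from torch import nn


# Condensed restatement of HeadsTIL for a self-contained sketch.
class HeadsTIL(nn.Module):
    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.heads = nn.ModuleDict()  # one independent head per task
        self.input_dim = input_dim

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        if f"{task_id}" not in self.heads:  # ModuleDict keys are strings
            self.heads[f"{task_id}"] = nn.Linear(self.input_dim, num_classes_t)

    def forward(self, feature: torch.Tensor, task_id: int) -> torch.Tensor:
        return self.heads[f"{task_id}"](feature)  # route to the task's head


heads = HeadsTIL(input_dim=64)            # made-up backbone output dim
heads.setup_task_id(1, num_classes_t=10)  # task 1 has 10 classes
heads.setup_task_id(2, num_classes_t=5)   # task 2 has 5 classes

feature = torch.randn(4, 64)              # a batch of 4 backbone features
logits_t1 = heads(feature, task_id=1)     # shape (4, 10)
logits_t2 = heads(feature, task_id=2)     # shape (4, 5)
```

Because TIL provides the task ID at test time, each task's logits stay confined to that task's own classes.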
```python
import torch
from torch import Tensor, nn


class HeadsCIL(nn.Module):
    r"""The output heads for Class-Incremental Learning (CIL). A single growing head over all classes from the CIL tasks takes the output from the backbone network and forwards it into logits for predicting the classes of all tasks."""

    def __init__(self, input_dim: int) -> None:
        r"""Initializes a CIL heads object with no heads.

        **Args:**
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.heads: nn.ModuleDict = nn.ModuleDict()  # initially no heads
        r"""CIL output heads are stored in a `ModuleDict`. Keys are task IDs and values are the corresponding `nn.Linear` heads. We use `ModuleDict` rather than `dict` so that `LightningModule` can track these model parameters, e.g. to move them to the device automatically and record them in model summaries."""

        self.input_dim: int = input_dim
        r"""The input dimension of the heads. Used when creating new heads."""

        self.task_id: int
        r"""Task ID counter indicating which task is being processed. Self-updated during the task loop. Starts from 1."""
        self.processed_task_ids: list[int] = []
        r"""Task IDs that have been processed."""

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        r"""Create the output head for task `task_id` when it arrives, if it does not exist yet. This must be done before `forward()` is called.

        **Args:**
        - **task_id** (`int`): the target task ID.
        - **num_classes_t** (`int`): the number of classes in the task.
        """
        self.task_id = task_id
        if f"{self.task_id}" not in self.heads:  # ModuleDict keys are strings
            self.processed_task_ids.append(task_id)
            self.heads[f"{self.task_id}"] = nn.Linear(self.input_dim, num_classes_t)

    def forward(self, feature: Tensor, task_id: int | None = None) -> Tensor:
        r"""The forward pass for data. The `task_id` the data are from is not provided; the heads of all classes seen so far are applied and their logits are concatenated.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.
        - **task_id** (`int` or `None`): the task ID the data are from. In CIL it is just a placeholder for API consistency with the TIL heads and is never used. Best practice is not to provide this argument and leave it at its default value.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """
        logits = torch.cat(
            [self.heads[f"{t}"](feature) for t in self.processed_task_ids], dim=-1
        )  # concatenate logits of classes from all heads

        return logits
```
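For illustration (not part of the library source), the sketch below shows the defining behavior of `HeadsCIL`: the logit dimension grows as tasks arrive, since inference covers all classes seen so far. The class body is restated in condensed form, and the backbone dimension `64` and class counts are made-up values.

```python
import torch
from torch import nn


# Condensed restatement of HeadsCIL for a self-contained sketch.
class HeadsCIL(nn.Module):
    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.heads = nn.ModuleDict()
        self.input_dim = input_dim
        self.processed_task_ids: list[int] = []

    def setup_task_id(self, task_id: int, num_classes_t: int) -> None:
        if f"{task_id}" not in self.heads:  # ModuleDict keys are strings
            self.processed_task_ids.append(task_id)
            self.heads[f"{task_id}"] = nn.Linear(self.input_dim, num_classes_t)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        # concatenate the logits of all classes seen so far
        return torch.cat(
            [self.heads[f"{t}"](feature) for t in self.processed_task_ids], dim=-1
        )


heads = HeadsCIL(input_dim=64)            # made-up backbone output dim
feature = torch.randn(4, 64)              # a batch of 4 backbone features

heads.setup_task_id(1, num_classes_t=10)
logits_after_t1 = heads(feature)          # shape (4, 10)

heads.setup_task_id(2, num_classes_t=5)
logits_after_t2 = heads(feature)          # shape (4, 15): output grows
```

Unlike TIL, no task ID is given at test time, so the model must discriminate among all classes across all tasks at once.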
```python
from torch import Tensor, nn


class HeadDIL(nn.Module):
    r"""The output head for Domain-Incremental Learning (DIL)."""

    def __init__(self, input_dim: int) -> None:
        r"""Initializes a DIL head object.

        **Args:**
        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.head: nn.Linear | None = None
        r"""The DIL output head."""

        self.input_dim: int = input_dim
        r"""Store the input dimension of the head. Used when creating the head."""

        self._if_head_setup: bool = False
        r"""Flag indicating whether the head has been set up."""

    def if_head_setup(self) -> bool:
        r"""Check whether the head has been set up.

        **Returns:**
        - **if_head_setup** (`bool`): whether the head has been set up.
        """
        return self._if_head_setup

    def get_head(self, task_id: int | None = None) -> nn.Linear:
        r"""Get the output head for DIL.

        **Args:**
        - **task_id** (`int` or `None`): the task ID the data are from. This does not matter for the DIL head, as there is only one head for all tasks; it is just a placeholder for API consistency with the TIL heads and is never used. Best practice is not to provide this argument and leave it at its default value.

        **Returns:**
        - **head** (`nn.Linear`): the output head for DIL.
        """
        return self.head

    def setup_task(self, num_classes: int) -> None:
        r"""Create the output head. This must be done before `forward()` is called.

        **Args:**
        - **num_classes** (`int`): the number of classes in the task.
        """
        self.head = nn.Linear(self.input_dim, num_classes)
        self._if_head_setup = True

    def forward(self, feature: Tensor, task_id: int | None = None) -> Tensor:
        r"""The forward pass for data. The `task_id` the data are from is not provided.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.
        - **task_id** (`int` or `None`): the task ID the data are from. In DIL it is just a placeholder for API consistency with the TIL heads and is never used. Best practice is not to provide this argument and leave it at its default value.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """
        logits = self.head(feature)

        return logits
```
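For illustration (not part of the library source), a minimal sketch of `HeadDIL`: every domain shares one head over a fixed class space, so the head is created once and reused. The class body is restated in condensed form; the backbone dimension `64` and class count are made-up values.

```python
import torch
from torch import nn


# Condensed restatement of HeadDIL for a self-contained sketch.
class HeadDIL(nn.Module):
    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.head: nn.Linear | None = None
        self.input_dim = input_dim

    def setup_task(self, num_classes: int) -> None:
        # one head for all domains, created once
        self.head = nn.Linear(self.input_dim, num_classes)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        return self.head(feature)


head = HeadDIL(input_dim=64)     # made-up backbone output dim
head.setup_task(num_classes=10)  # the same 10 classes in every domain

feature = torch.randn(4, 64)     # a batch of 4 backbone features
logits = head(feature)           # shape (4, 10), regardless of domain
```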
```python
import torch
from torch import Tensor, nn


class HeadsMTL(nn.Module):
    r"""The output heads for Multi-Task Learning (MTL). An independent head assigned to each task takes the output from the backbone network and forwards it into logits for predicting the classes of that task."""

    def __init__(self, input_dim: int) -> None:
        r"""Initializes an MTL heads object.

        **Args:**
        - **input_dim** (`int`): the input dimension of the heads. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.heads: nn.ModuleDict = nn.ModuleDict()
        r"""MTL output heads are stored independently in a `ModuleDict`. Keys are task IDs and values are the corresponding `nn.Linear` heads. We use `ModuleDict` rather than `dict` so that `LightningModule` can track these model parameters, e.g. to move them to the device automatically and record them in model summaries.

        Note that the task IDs must be strings so that `LightningModule` can identify this part of the model."""

        self.input_dim: int = input_dim
        r"""Store the input dimension of the heads. Used when creating new heads."""

    def setup_tasks(self, task_ids: list[int], num_classes: dict[int, int]) -> None:
        r"""Create the output heads. This must be done before `forward()` is called.

        **Args:**
        - **task_ids** (`list[int]`): the target task IDs.
        - **num_classes** (`dict[int, int]`): the number of classes in each task. Keys are task IDs and values are the number of classes for the corresponding task.
        """
        for task_id in task_ids:
            self.heads[f"{task_id}"] = nn.Linear(self.input_dim, num_classes[task_id])

    def get_head(self, task_id: int) -> nn.Linear:
        r"""Get the output head for task `task_id`.

        **Args:**
        - **task_id** (`int`): the target task ID.

        **Returns:**
        - **head_t** (`nn.Linear`): the output head for task `task_id`.
        """
        return self.heads[f"{task_id}"]

    def forward(self, feature: Tensor, task_ids: int | Tensor) -> Tensor:
        r"""The forward pass for the data. Each head is selected according to the task ID and the corresponding features are passed through it.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.
        - **task_ids** (`int` | `Tensor`): the task ID(s) for the input data. If the entire input batch is from the same task, this can be a single integer. When it is a `Tensor`, all tasks in the batch must have the same number of classes.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """
        if isinstance(task_ids, int):
            head_t = self.get_head(task_ids)
            logits = head_t(feature)

        elif isinstance(task_ids, Tensor):
            logits_list = []
            for task_id in torch.unique(task_ids):  # for each unique task in the batch
                idx = (task_ids == task_id).nonzero(as_tuple=True)[
                    0
                ]  # indices of the current task in the batch
                features_t = feature[idx]  # get the features for the current task
                head_t = self.get_head(task_id.item())
                logits_t = head_t(features_t)
                logits_list.append((idx, logits_t))

            # reconstruct the logits tensor in the order of task_ids
            logits = torch.zeros(len(task_ids), logits_t.size(1), device=feature.device)
            for idx, logits_t in logits_list:
                logits[idx] = logits_t

        return logits
```
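For illustration (not part of the library source), the sketch below shows the mixed-batch path of `HeadsMTL`: samples in one batch come from different tasks, and each sample is routed to its own task's head. The class body is restated in condensed form; the backbone dimension `64` and class counts are made-up values, and both tasks are given the same number of classes because the batched path assumes a shared output dimension.

```python
import torch
from torch import nn


# Condensed restatement of HeadsMTL for a self-contained sketch.
class HeadsMTL(nn.Module):
    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.heads = nn.ModuleDict()
        self.input_dim = input_dim

    def setup_tasks(self, task_ids: list[int], num_classes: dict[int, int]) -> None:
        for task_id in task_ids:
            self.heads[f"{task_id}"] = nn.Linear(self.input_dim, num_classes[task_id])

    def forward(self, feature: torch.Tensor, task_ids: torch.Tensor) -> torch.Tensor:
        logits = None
        for task_id in torch.unique(task_ids):  # each unique task in the batch
            idx = (task_ids == task_id).nonzero(as_tuple=True)[0]
            logits_t = self.heads[f"{task_id.item()}"](feature[idx])
            if logits is None:  # allocate once the output dim is known
                logits = torch.zeros(
                    len(task_ids), logits_t.size(1), device=feature.device
                )
            logits[idx] = logits_t  # scatter back into batch order
        return logits


heads = HeadsMTL(input_dim=64)  # made-up backbone output dim
heads.setup_tasks([1, 2], num_classes={1: 10, 2: 10})

feature = torch.randn(4, 64)            # a batch of 4 backbone features
task_ids = torch.tensor([1, 2, 1, 2])   # per-sample task IDs, mixed batch
logits = heads(feature, task_ids)       # shape (4, 10), in batch order
```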
```python
from torch import Tensor, nn


class HeadSTL(nn.Module):
    r"""The output head for Single-Task Learning (STL)."""

    def __init__(self, input_dim: int) -> None:
        r"""Initializes an STL head object.

        **Args:**
        - **input_dim** (`int`): the input dimension of the head. Must be equal to the `output_dim` of the connected backbone.
        """
        super().__init__()

        self.head: nn.Linear | None = None
        r"""The STL output head."""

        self.input_dim: int = input_dim
        r"""Store the input dimension of the head. Used when creating the head."""

    def setup_task(self, num_classes: int) -> None:
        r"""Create the output head. This must be done before `forward()` is called.

        **Args:**
        - **num_classes** (`int`): the number of classes in the task.
        """
        self.head = nn.Linear(self.input_dim, num_classes)

    def forward(self, feature: Tensor) -> Tensor:
        r"""The forward pass for data.

        **Args:**
        - **feature** (`Tensor`): the feature tensor from the backbone network.

        **Returns:**
        - **logits** (`Tensor`): the output logits tensor.
        """
        logits = self.head(feature)

        return logits
```