clarena.cl_algorithms.cbp

The submodule in cl_algorithms for the CBP (Continual Backpropagation) algorithm.

  1r"""
  2The submodule in `cl_algorithms` for [CBP (Continual Backpropagation)](https://www.nature.com/articles/s41586-024-07711-7) algorithm.
  3"""
  4
  5__all__ = ["CBP"]
  6
  7import logging
  8from typing import Any
  9
 10import torch
 11from torch import Tensor
 12
 13from clarena.backbones import CLBackbone
 14from clarena.cl_algorithms import Finetuning
 15from clarena.cl_heads import HeadsCIL, HeadsTIL
 16from clarena.utils.transforms import min_max_normalise
 17
 18# always get logger for built-in logging in each module
 19pylogger = logging.getLogger(__name__)
 20
 21
 22class CBP(Finetuning):
 23    r"""CBP (Continual Backpropagation) algorithm.
 24
 25    [CBP (Continual Backpropagation, 2024)](https://www.nature.com/articles/s41586-024-07711-7) is a continual learning approach that reinitialises a small number of units during training, using an utility measures to determine which units to reinitialise. It aims to address loss of plasticity problem for learning new tasks, yet not very well solve the catastrophic forgetting problem in continual learning.
 26
 27    We implement CBP as a subclass of Finetuning algorithm, as CBP has the same `forward()`, `training_step()`, `validation_step()` and `test_step()` method as `Finetuning` class.
 28    """
 29
    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL,
        replacement_rate: float,
        maturity_threshold: int,
        utility_decay_rate: float,
    ) -> None:
        r"""Initialise the CBP algorithm with the network and the CBP hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL`): output heads.
        - **replacement_rate** (`float`): the replacement rate of units. It is the proportion of eligible units to be reinitialised at each training step.
        - **maturity_threshold** (`int`): the maturity threshold of units. It is the number of training steps before a unit can be reinitialised.
        - **utility_decay_rate** (`float`): the utility decay rate of units. It is the rate at which the utility of a unit decays over time.
        """
        Finetuning.__init__(self, backbone=backbone, heads=heads)

        self.replacement_rate: float = replacement_rate
        r"""Store the replacement rate of units. """
        self.maturity_threshold: int = maturity_threshold
        r"""Store the maturity threshold of units. """
        self.utility_decay_rate: float = utility_decay_rate
        r"""Store the utility decay rate of units. """

        self.contribution_utility: dict[str, Tensor] = {}
        r"""Store the contribution utility of units. See equation (1) in the [continual backpropagation paper](https://www.nature.com/articles/s41586-024-07711-7). Keys are layer names and values are the utility tensor for the layer. The utility tensor has the same size as the feature tensor, i.e. (number of units). """
        self.num_replacements: dict[str, float] = {}
        r"""Store the accumulated (possibly fractional) number of pending unit replacements in each layer. Keys are layer names and values are the accumulated number of replacements for the layer. """
        self.age: dict[str, Tensor] = {}
        r"""Store the age of units. Keys are layer names and values are the age tensor for the layer. The age tensor has the same size as the feature tensor, i.e. (number of units). """

    def on_train_start(self) -> None:
        r"""Initialise the utility, number of replacements and age for each layer as zeros."""

        # initialise the utility, number of replacements and age as zeros at the beginning of the first task. This cannot be done in the `__init__()` method because `self.device` is not available at that time.
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name
                num_units = layer.weight.shape[0]

                self.contribution_utility[layer_name] = torch.zeros(num_units).to(
                    self.device
                )
                self.num_replacements[layer_name] = 0.0
                self.age[layer_name] = torch.zeros(num_units).to(self.device)

    def on_train_batch_end(
        self, outputs: dict[str, Any], batch: Any, batch_idx: int
    ) -> None:
        r"""Update the contribution utility and age of units after each training step, and reinitialise units based on the utility measure. This is the core of the CBP algorithm.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, which is the return of the `training_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        - **batch_idx** (`int`): the index of the current batch.
        """

        hidden_features = outputs["hidden_features"]

        for layer_name in self.backbone.weighted_layer_names:
            # layer-wise operation

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # update age
            self.age[layer_name] += 1

            # calculate current contribution utility
            current_contribution_utility = (
                torch.mean(
                    torch.abs(hidden_features[layer_name]),
                    dim=0,  # average the features over batch samples
                )
                * torch.sum(
                    torch.abs(layer.weight),
                    dim=1,  # sum the weight magnitudes over the input dimension
                )
            ).detach()
            current_contribution_utility = min_max_normalise(
                current_contribution_utility
            )  # normalise the utility to [0,1] to avoid linearly increasing utility

            # update utility as an exponential moving average
            self.contribution_utility[layer_name] = (
                self.utility_decay_rate * self.contribution_utility[layer_name]
                + (1 - self.utility_decay_rate) * current_contribution_utility
            )

            # find eligible (mature) units
            eligible_mask = self.age[layer_name] > self.maturity_threshold
            eligible_indices = torch.where(eligible_mask)[0]

            # accumulate the (possibly fractional) number of replacements; truncating to `int` here would never trigger a replacement for typical replacement rates
            num_eligible_units = eligible_indices.numel()
            self.num_replacements[layer_name] += (
                self.replacement_rate * num_eligible_units
            )

            # once at least one full replacement has accrued, execute a replacement
            if self.num_replacements[layer_name] >= 1:

                # find the unit with the smallest utility (relative to its age) among eligible units
                replaced_unit_idx = eligible_indices[
                    torch.argmin(
                        self.contribution_utility[layer_name][eligible_indices]
                        / self.age[layer_name][eligible_indices]
                    ).item()
                ]

                # reinitialise the weights connecting to the unit in the adjacent layer, if any
                preceding_layer_name = self.backbone.preceding_layer_name(layer_name)
                if preceding_layer_name is not None:
                    preceding_layer = self.backbone.get_layer_by_name(
                        preceding_layer_name
                    )
                    with torch.no_grad():
                        preceding_layer.weight[:, replaced_unit_idx] = torch.rand_like(
                            preceding_layer.weight[:, replaced_unit_idx]
                        )

                # reinitialise the unit's own weights in this layer
                with torch.no_grad():
                    layer.weight[replaced_unit_idx] = torch.rand_like(
                        layer.weight[replaced_unit_idx]
                    )

                # reset the utility of the replaced unit
                self.contribution_utility[layer_name][replaced_unit_idx] = 0.0

                # reset the age of the replaced unit
                self.age[layer_name][replaced_unit_idx] = 0

                # deduct the executed replacement
                self.num_replacements[layer_name] -= 1

CBP (Continual Backpropagation) algorithm.

CBP (Continual Backpropagation, 2024) is a continual learning approach that reinitialises a small number of low-utility units during training, using a utility measure to decide which units to reinitialise. It addresses the loss-of-plasticity problem when learning new tasks, but does not directly tackle the catastrophic forgetting problem in continual learning.

We implement CBP as a subclass of the Finetuning algorithm, as CBP has the same forward(), training_step(), validation_step() and test_step() methods as the Finetuning class.

CBP(backbone: clarena.backbones.CLBackbone, heads: clarena.cl_heads.HeadsTIL | clarena.cl_heads.HeadsCIL, replacement_rate: float, maturity_threshold: int, utility_decay_rate: float)

Initialise the CBP algorithm with the network and the CBP hyperparameters. A usage sketch follows the argument list below.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL): output heads.
  • replacement_rate (float): the replacement rate of units. It is the proportion of eligible units to be reinitialised at each training step.
  • maturity_threshold (int): the maturity threshold of units. It is the number of training steps before a unit can be reinitialised.
  • utility_decay_rate (float): the utility decay rate of units. It is the rate at which the utility of a unit decays over time.
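
A usage sketch. The hyperparameter values here are illustrative assumptions, not defaults taken from this page, and backbone / heads are assumed to be pre-built CLBackbone and HeadsTIL (or HeadsCIL) instances constructed elsewhere in the library:

    from clarena.cl_algorithms import CBP

    # `backbone` and `heads` are assumed to be pre-built CLBackbone and
    # HeadsTIL (or HeadsCIL) instances; the values below are illustrative only.
    cbp = CBP(
        backbone=backbone,
        heads=heads,
        replacement_rate=1e-4,    # fraction of eligible units replaced per step
        maturity_threshold=100,   # steps before a unit becomes eligible
        utility_decay_rate=0.99,  # decay rate of the utility moving average
    )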
replacement_rate: float

Store the replacement rate of units.

maturity_threshold: int

Store the maturity threshold of units.

utility_decay_rate: float

Store the utility decay rate of units.

contribution_utility: dict[str, torch.Tensor]

Store the contribution utility of units. See equation (1) in the continual backpropagation paper. Keys are layer names and values are the utility tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units).
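
In symbols, the update implemented in on_train_batch_end (assuming min_max_normalise rescales a vector to [0, 1] across the units of the layer, as the source comment states) reads, for layer l with batch size B, hidden features h, weight matrix W and utility decay rate η:

    \hat{u}_{l,i} = \Big( \tfrac{1}{B} \sum_{b=1}^{B} \big| h_{l,i}^{(b)} \big| \Big) \cdot \sum_{j} \big| W_{l,ij} \big| ,
    \qquad
    u_{l} \leftarrow \eta \, u_{l} + (1 - \eta) \, \operatorname{minmax}\big( \hat{u}_{l} \big)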

num_replacements: dict[str, float]

Store the accumulated (possibly fractional) number of pending unit replacements in each layer. Keys are layer names and values are the accumulated number of replacements for the layer.

age: dict[str, torch.Tensor]

Store the age of units. Keys are layer names and values are the age tensor for the layer. The age tensor has the same size as the feature tensor, i.e. (number of units).

def on_train_start(self) -> None:

Initialise the utility, number of replacements and age for each layer as zeros.
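
For illustration, with a hypothetical backbone whose weighted_layer_names is ["fc1", "fc2"] with 128 and 64 units respectively, the state after this hook would be:

    cbp.contribution_utility["fc1"]  # tensor of 128 zeros, on cbp.device
    cbp.num_replacements["fc1"]      # 0.0
    cbp.age["fc2"]                   # tensor of 64 zeros, on cbp.device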

def on_train_batch_end(self, outputs: dict[str, typing.Any], batch: typing.Any, batch_idx: int) -> None:

Update the contribution utility and age of units after each training step, and reinitialise units based on the utility measure. This is the core of the CBP algorithm.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, which is the return of the training_step() method in the CLAlgorithm.
  • batch (Any): the training data batch.
  • batch_idx (int): the index of the current batch.
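
To make the replacement accounting concrete (hypothetical numbers): with replacement_rate = 1e-4 and 500 eligible units in a layer, 1e-4 × 500 = 0.05 replacements accrue per training step, so one unit is reinitialised roughly every 20 steps once units pass the maturity threshold. This is why the counter accumulates fractional amounts rather than being truncated to an integer each step.

    replacement_rate = 1e-4    # hypothetical values for illustration
    num_eligible_units = 500
    accrual_per_step = replacement_rate * num_eligible_units  # 0.05
    steps_between_replacements = 1 / accrual_per_step         # 20.0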