clarena.cl_algorithms.cbp

The submodule in cl_algorithms for the CBP (Continual Backpropagation) algorithm.

r"""
The submodule in `cl_algorithms` for [CBP (Continual Backpropagation)](https://www.nature.com/articles/s41586-024-07711-7) algorithm.
"""

__all__ = ["CBP"]

import logging
from typing import Any

import torch
from torch import Tensor

from clarena.backbones import CLBackbone
from clarena.cl_algorithms import Finetuning
from clarena.heads import HeadDIL, HeadsCIL, HeadsTIL
from clarena.utils.transforms import min_max_normalize

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class CBP(Finetuning):
    r"""[CBP (Continual Backpropagation)](https://www.nature.com/articles/s41586-024-07711-7) algorithm.

    A continual learning approach that reinitializes a small number of units during training, using a utility measure to determine which units to reinitialize. It aims to address the loss-of-plasticity problem when learning new tasks, but does not directly address the catastrophic forgetting problem in continual learning.

    We implement CBP as a subclass of the `Finetuning` algorithm, as CBP shares the same `forward()`, `training_step()`, `validation_step()` and `test_step()` methods with the `Finetuning` class.
    """

    def __init__(
        self,
        backbone: CLBackbone,
        heads: HeadsTIL | HeadsCIL | HeadDIL,
        replacement_rate: float,
        maturity_threshold: int,
        utility_decay_rate: float,
        non_algorithmic_hparams: dict[str, Any] = {},
        **kwargs,
    ) -> None:
        r"""Initialize the CBP algorithm with the network and its hyperparameters.

        **Args:**
        - **backbone** (`CLBackbone`): backbone network.
        - **heads** (`HeadsTIL` | `HeadsCIL` | `HeadDIL`): output heads.
        - **replacement_rate** (`float`): the replacement rate of units. It is the percentage of eligible units to be reinitialized during each training step.
        - **maturity_threshold** (`int`): the maturity threshold of units. It is the number of training steps before a unit can be reinitialized.
        - **utility_decay_rate** (`float`): the utility decay rate of units. It is the rate at which the utility of a unit decays over time.
        - **non_algorithmic_hparams** (`dict[str, Any]`): non-algorithmic hyperparameters that are not related to the algorithm itself, passed to this `LightningModule` object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the `save_hyperparameters()` method. This is useful for experiment configuration and reproducibility.
        - **kwargs**: reserved for multiple inheritance.
        """
        super().__init__(
            backbone=backbone,
            heads=heads,
            non_algorithmic_hparams=non_algorithmic_hparams,
            **kwargs,
        )

        self.replacement_rate: float = replacement_rate
        r"""The replacement rate of units."""
        self.maturity_threshold: int = maturity_threshold
        r"""The maturity threshold of units."""
        self.utility_decay_rate: float = utility_decay_rate
        r"""The utility decay rate of units."""

        # save additional algorithmic hyperparameters
        self.save_hyperparameters(
            "replacement_rate",
            "maturity_threshold",
            "utility_decay_rate",
        )

        self.contribution_utility: dict[str, Tensor] = {}
        r"""The contribution utility of units. See equation (1) in the [continual backpropagation paper](https://www.nature.com/articles/s41586-024-07711-7). Keys are layer names and values are the utility tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, )."""
        self.num_replacements: dict[str, int] = {}
        r"""The number of replacements of units in each layer. Keys are layer names and values are the number of replacements for the layer."""
        self.age: dict[str, Tensor] = {}
        r"""The age of units. Keys are layer names and values are the age tensor for the layer. The age tensor is the same size as the feature tensor with size (number of units, )."""

    def on_train_start(self) -> None:
        r"""Initialize the utility, number of replacements and age for each layer as zeros."""

        # initialize the utility, number of replacements and age as zeros at the beginning of the first task. This cannot be done in the `__init__()` method as `self.device` is not available at that time.
        if self.task_id == 1:
            for layer_name in self.backbone.weighted_layer_names:
                layer = self.backbone.get_layer_by_name(
                    layer_name
                )  # get the layer by its name
                num_units = layer.weight.shape[0]

                self.contribution_utility[layer_name] = torch.zeros(num_units).to(
                    self.device
                )
                self.num_replacements[layer_name] = 0
                self.age[layer_name] = torch.zeros(num_units).to(self.device)

    def on_train_batch_end(
        self, outputs: dict[str, Any], batch: Any, batch_idx: int
    ) -> None:
        r"""Update the contribution utility and age of units after each training step, and reinitialize units based on the utility measure. This is the core of the CBP algorithm.

        **Args:**
        - **outputs** (`dict[str, Any]`): the outputs of the training step, which are the returns of the `training_step()` method in the `CLAlgorithm`.
        - **batch** (`Any`): the training data batch.
        - **batch_idx** (`int`): the index of the current batch.
        """

        activations = outputs["activations"]

        for layer_name in self.backbone.weighted_layer_names:
            # layer-wise operation

            layer = self.backbone.get_layer_by_name(
                layer_name
            )  # get the layer by its name

            # update age
            self.age[layer_name] += 1

            # calculate current contribution utility
            current_contribution_utility = (
                torch.mean(
                    torch.abs(activations[layer_name]),
                    dim=0,  # average the features over batch samples
                )
                * torch.sum(
                    torch.abs(layer.weight),
                    dim=1,  # sum the weight magnitudes over the input dimension
                )
            ).detach()
            current_contribution_utility = min_max_normalize(
                current_contribution_utility
            )  # normalize the utility to [0, 1] to avoid linearly increasing utility

            # update utility
            self.contribution_utility[layer_name] = (
                self.utility_decay_rate * self.contribution_utility[layer_name]
                + (1 - self.utility_decay_rate) * current_contribution_utility
            )

            # find eligible units
            eligible_mask = self.age[layer_name] > self.maturity_threshold
            eligible_indices = torch.where(eligible_mask)[0]

            # update the number of replacements
            num_eligible_units = eligible_indices.numel()
            self.num_replacements[layer_name] += int(
                self.replacement_rate * num_eligible_units
            )

            # if the number of replacements is greater than 1, execute the replacement
            if self.num_replacements[layer_name] > 1:

                # find the unit with the smallest age-normalized utility among eligible units
                replaced_unit_idx = eligible_indices[
                    torch.argmin(
                        self.contribution_utility[layer_name][eligible_indices]
                        / self.age[layer_name][eligible_indices]
                    ).item()
                ]

                # reinitialize the input weights of the unit
                preceding_layer = self.backbone.preceding_layer(layer_name)
                if preceding_layer is not None:
                    with torch.no_grad():
                        preceding_layer.weight[:, replaced_unit_idx] = torch.rand_like(
                            preceding_layer.weight[:, replaced_unit_idx]
                        )

                # reinitialize the output weights of the unit
                with torch.no_grad():
                    layer.weight[replaced_unit_idx] = torch.rand_like(
                        layer.weight[replaced_unit_idx]
                    )

                # reset utility
                self.contribution_utility[layer_name][replaced_unit_idx] = 0.0

                # reset age
                self.age[layer_name][replaced_unit_idx] = 0

                # update the number of replacements
                self.num_replacements[layer_name] -= 1

CBP (Continual Backpropagation) algorithm.

A continual learning approach that reinitializes a small number of units during training, using a utility measure to determine which units to reinitialize. It aims to address the loss-of-plasticity problem when learning new tasks, but does not directly address the catastrophic forgetting problem in continual learning.

We implement CBP as a subclass of the Finetuning algorithm, as CBP shares the same forward(), training_step(), validation_step() and test_step() methods with the Finetuning class.

CBP( backbone: clarena.backbones.CLBackbone, heads: clarena.heads.HeadsTIL | clarena.heads.HeadsCIL | clarena.heads.HeadDIL, replacement_rate: float, maturity_threshold: int, utility_decay_rate: float, non_algorithmic_hparams: dict[str, typing.Any] = {}, **kwargs)

Initialize the CBP algorithm with the network and its hyperparameters.

Args:

  • backbone (CLBackbone): backbone network.
  • heads (HeadsTIL | HeadsCIL | HeadDIL): output heads.
  • replacement_rate (float): the replacement rate of units. It is the percentage of eligible units to be reinitialized during each training step.
  • maturity_threshold (int): the maturity threshold of units. It is the number of training steps before a unit can be reinitialized.
  • utility_decay_rate (float): the utility decay rate of units. It is the rate at which the utility of a unit decays over time.
  • non_algorithmic_hparams (dict[str, Any]): non-algorithmic hyperparameters that are not related to the algorithm itself, passed to this LightningModule object from the config, such as optimizer and learning rate scheduler configurations. They are saved for Lightning APIs through the save_hyperparameters() method. This is useful for experiment configuration and reproducibility.
  • kwargs: Reserved for multiple inheritance.
replacement_rate: float

The replacement rate of units.

maturity_threshold: int

The maturity threshold of units.

utility_decay_rate: float

The utility decay rate of units.
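The decay rate acts as the smoothing factor of an exponential moving average over the per-step utility: with decay rate η, the running utility becomes u ← η·u + (1 − η)·û. A minimal illustration of this update (plain Python; `update_utility` is a hypothetical helper, not part of the library):

```python
def update_utility(running_utility, current_utility, utility_decay_rate):
    """Exponential moving average of unit utilities.

    A higher decay rate keeps more of the historical utility and reacts
    more slowly to the current batch.
    """
    return [
        utility_decay_rate * u + (1 - utility_decay_rate) * c
        for u, c in zip(running_utility, current_utility)
    ]

# with decay rate 0.9, the running utility moves 10% toward the current value
updated = update_utility([1.0, 0.0], [0.0, 1.0], 0.9)
```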

contribution_utility: dict[str, torch.Tensor]

The contribution utility of units. See equation (1) in the continual backpropagation paper. Keys are layer names and values are the utility tensor for the layer. The utility tensor is the same size as the feature tensor with size (number of units, ).
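Concretely, the per-unit measure combines the mean absolute activation over the batch with the summed absolute magnitude of the unit's weights, then min-max normalizes to [0, 1] (mirroring `min_max_normalize` in the source). A framework-free sketch of this computation (plain Python lists instead of tensors; `contribution_utility` here is an illustrative helper, not the library attribute):

```python
def contribution_utility(activations, weights):
    """Per-unit utility: mean |activation| over the batch, times sum of |weight|.

    activations: one inner list of unit activations per batch sample.
    weights: one inner list of weights per unit.
    """
    num_units = len(weights)
    utility = []
    for j in range(num_units):
        mean_abs_act = sum(abs(sample[j]) for sample in activations) / len(activations)
        weight_magnitude = sum(abs(w) for w in weights[j])
        utility.append(mean_abs_act * weight_magnitude)
    # min-max normalize to [0, 1] so the running utility stays bounded
    lo, hi = min(utility), max(utility)
    span = hi - lo if hi > lo else 1.0
    return [(u - lo) / span for u in utility]

# two batch samples, two units: raw utilities are [4.0, 2.0], normalized to [1.0, 0.0]
utility = contribution_utility([[1.0, -2.0], [3.0, 0.0]], [[1.0, 1.0], [2.0, 0.0]])
```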

num_replacements: dict[str, int]

The number of replacements of units in each layer. Keys are layer names and values are the number of replacements for the layer.

age: dict[str, torch.Tensor]

The age of units. Keys are layer names and values are the age tensor for the layer. The age tensor is the same size as the feature tensor with size (number of units, ).

def on_train_start(self) -> None:

Initialize the utility, number of replacements and age for each layer as zeros.

def on_train_batch_end(self, outputs: dict[str, typing.Any], batch: Any, batch_idx: int) -> None:

Update the contribution utility and age of units after each training step, and reinitialize units based on the utility measure. This is the core of the CBP algorithm.

Args:

  • outputs (dict[str, Any]): the outputs of the training step, which are the returns of the training_step() method in the CLAlgorithm.
  • batch (Any): the training data batch.
  • batch_idx (int): the index of the current batch.
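The selection step in this hook can be sketched as follows: only units whose age exceeds the maturity threshold are eligible, and among those the unit with the lowest age-normalized utility is replaced. A minimal standalone version (plain Python; `select_unit_to_replace` is a hypothetical helper mirroring the logic of `on_train_batch_end()`, not a library function):

```python
def select_unit_to_replace(utility, age, maturity_threshold):
    """Return the index of the eligible unit with the lowest utility/age ratio.

    Eligible units are those whose age exceeds the maturity threshold;
    returns None when no unit is mature enough to be replaced.
    """
    eligible = [j for j, a in enumerate(age) if a > maturity_threshold]
    if not eligible:
        return None
    return min(eligible, key=lambda j: utility[j] / age[j])

# unit 2 is too young; among units 0 and 1, unit 1 has the lower ratio
chosen = select_unit_to_replace([0.5, 0.1, 0.9], [10, 10, 2], 5)
```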