clarena.cl_algorithms.regularisers.parameter_change

The submodule in regularisers for parameter change regularisation.

 1r"""The submodule in `regularisers` for parameter change regularisation."""
 2
 3__all__ = ["ParameterChangeReg"]
 4
 5import torch
 6from torch import Tensor, nn
 7
 8
 9class ParameterChangeReg(nn.Module):
10    r"""Parameter change regulariser.
11
12    $$R(\theta) = \text{factor} * \sum_i w_i \|\theta_i - \theta^\star_i\|^p$$
13
14    It promotes the target set of parameters $\theta = {\theta_i}_i$ not changing too much from another set of parameters $\theta^\star = {\theta^\star_i}_i$. The parameter distance here is $L^p$ distance. The regularisation can be parameter-wise weighted, i.e. $w_i$ in the formula.
15
16    It is used in:
17    - [L2 Regularisation algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as a L2 regulariser for the current task parameters to prevent them from changing too much from the previous task parameters.
18    - [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as a weighted L2 regulariser for the current task parameters to prevent them from changing too much from the previous task parameters. The regularisation weights are parameter importance measure calculated from fisher information. See equation 3 in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114).
19    """
20
21    def __init__(
22        self,
23        factor: float,
24        p_norm: float,
25    ) -> None:
26        r"""Initialise the regulariser.
27
28        **Args:**
29        - **factor** (`float`): the regularisation factor. Note that it is $\frac{\lambda}{2}$ rather than $\lambda$ in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114).
30        - **p_norm** (`float`): the norm of the distance, should be a positive float.
31        """
32        nn.Module.__init__(self)
33
34        self.factor = factor
35        """Store the regularisation factor for parameter change."""
36        self.p_norm = p_norm
37        """Store the the norm of the distance of two set of parameters. """
38
39    def forward(
40        self,
41        target_model: nn.Module,
42        ref_model: nn.Module,
43        weights: dict[str, Tensor],
44    ) -> Tensor:
45        r"""Calculate the regularisation loss.
46
47        **Args:**
48        - **target_model** (`nn.Module`): the model of the target parameters. In EWC, it's the model of current training task.
49        - **ref_model** (`nn.Module`): the reference model that you want target model parameters to prevent changing from. The reference model must have the same structure as the target model. In EWC, it's the model of one of the previous tasks.
50        - **weights** (`dict[str, Tensor]`): the regularisation weight for each parameter. Keys are parameter names and values are the weight tensors. The weight tensors must match the shape of model parameters. In EWC, it's the importance measure of each parameter, calculated from fisher information thing.
51
52        **Returns:**
53        - **reg** (`Tensor`): the parameter change regularisation value.
54        """
55        reg = 0.0
56
57        for (param_name, target_param), (param_name, ref_param) in zip(
58            target_model.named_parameters(), ref_model.named_parameters()
59        ):
60            weight = weights[param_name]
61            reg += torch.sum(weight * (target_param - ref_param).norm(p=self.p_norm))
62
63        return self.factor * reg
class ParameterChangeReg(torch.nn.modules.module.Module):

Parameter change regulariser.

$$R(\theta) = \text{factor} \cdot \sum_i w_i \|\theta_i - \theta^\star_i\|^p$$

It encourages the target set of parameters $\theta = \{\theta_i\}_i$ not to change too much from a reference set of parameters $\theta^\star = \{\theta^\star_i\}_i$. The parameter distance is the $L^p$ distance, and the regularisation can be weighted parameter-wise, i.e. by $w_i$ in the formula.

It is used in:

  • L2 Regularisation algorithm: as an L2 regulariser on the current task parameters to prevent them from changing too much from the previous task parameters (see the sketch after this list).
  • EWC (Elastic Weight Consolidation) algorithm: as a weighted L2 regulariser on the current task parameters to prevent them from changing too much from the previous task parameters. The regularisation weights are parameter importance measures calculated from the Fisher information. See equation 3 in the EWC paper.
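
For illustration, here is a minimal usage sketch of the plain L2 case (p_norm of 2 with uniform weights); the linear model and factor value are hypothetical stand-ins, not part of clarena:

import copy

import torch
from torch import nn

from clarena.cl_algorithms.regularisers.parameter_change import ParameterChangeReg

# A hypothetical model for the current task.
model = nn.Linear(4, 2)
# A frozen copy of the previous-task parameters serves as the reference.
ref_model = copy.deepcopy(model)

# Plain L2 regularisation: p = 2 with a uniform weight of 1 for every parameter.
regulariser = ParameterChangeReg(factor=0.5, p_norm=2.0)
weights = {name: torch.ones_like(param) for name, param in model.named_parameters()}

# During training on the new task, add the penalty to the task loss.
task_loss = torch.tensor(0.0)  # stand-in for the actual task loss
total_loss = task_loss + regulariser(model, ref_model, weights)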
ParameterChangeReg(factor: float, p_norm: float)

Initialise the regulariser.

Args:

  • factor (float): the regularisation factor. Note that it corresponds to $\frac{\lambda}{2}$ rather than $\lambda$ in the EWC paper (a concrete example follows below).
  • p_norm (float): the norm of the distance; should be a positive float.
factor

Store the regularisation factor for parameter change.

p_norm

Store the norm of the distance between the two sets of parameters.
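
As a concrete reading of the factor note above: to reproduce the EWC objective $\frac{\lambda}{2} \sum_i F_i (\theta_i - \theta^\star_i)^2$, pass half of the paper's $\lambda$. The value below is illustrative only:

ewc_lambda = 400.0  # illustrative; the lambda hyperparameter from the EWC paper
regulariser = ParameterChangeReg(factor=ewc_lambda / 2, p_norm=2.0)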

def forward( self, target_model: torch.nn.modules.module.Module, ref_model: torch.nn.modules.module.Module, weights: dict[str, torch.Tensor]) -> torch.Tensor:

Calculate the regularisation loss.

Args:

  • target_model (nn.Module): the model holding the target parameters. In EWC, it is the model of the current training task.
  • ref_model (nn.Module): the reference model that the target parameters should not deviate from. It must have the same structure as the target model. In EWC, it is the model of one of the previous tasks.
  • weights (dict[str, Tensor]): the regularisation weight for each parameter. Keys are parameter names and values are the weight tensors, which must match the shapes of the model parameters. In EWC, these are the parameter importance measures calculated from the Fisher information (a rough sketch follows below).

Returns:

  • reg (Tensor): the parameter change regularisation value.
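
As a rough sketch of where the EWC weights could come from, the dict might be built from a diagonal Fisher information estimate over the previous task's data. The estimate_fisher_weights helper and the data loader below are hypothetical, not clarena API:

import torch
from torch import nn


def estimate_fisher_weights(
    model: nn.Module,
    dataloader,  # hypothetical loader over the previous task's data
    loss_fn: nn.Module = nn.CrossEntropyLoss(),
) -> dict[str, torch.Tensor]:
    """Approximate diagonal Fisher information as squared gradients averaged over batches."""
    fisher = {name: torch.zeros_like(param) for name, param in model.named_parameters()}
    num_batches = 0
    for x, y in dataloader:
        model.zero_grad()
        loss_fn(model(x), y).backward()
        for name, param in model.named_parameters():
            if param.grad is not None:
                fisher[name] += param.grad.detach() ** 2
        num_batches += 1
    return {name: f / max(num_batches, 1) for name, f in fisher.items()}

The returned dict can then be passed as weights to forward, with the current-task model as target_model and the frozen previous-task model as ref_model.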