clarena.cl_algorithms.regularizers.parameter_change

The submodule in regularizers for parameter change regularization.

r"""The submodule in `regularizers` for parameter change regularization."""

__all__ = ["ParameterChangeReg"]

import logging

import torch
from torch import Tensor, nn

# always get logger for built-in logging in each module
pylogger = logging.getLogger(__name__)


class ParameterChangeReg(nn.Module):
    r"""Parameter change regularizer.

    $$R(\theta) = \text{factor} \cdot \sum_i w_i \|\theta_i - \theta^\star_i\|^2$$

    It penalizes the target set of parameters $\theta = \{\theta_i\}_i$ for deviating from a reference set of parameters $\theta^\star = \{\theta^\star_i\}_i$, measured by the $L^2$ distance. The regularization can be weighted parameter-wise, through the $w_i$ in the formula.

    It is used in:
    - [L2 Regularization algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as an L2 regularizer that prevents the current task parameters from changing too much from the previous task parameters.
    - [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as a weighted L2 regularizer that prevents the current task parameters from changing too much from the previous task parameters. The regularization weights are parameter importance measures calculated from the Fisher information. See equation 3 in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114).
    """

    def __init__(
        self,
        factor: float,
    ) -> None:
        r"""
        **Args:**
        - **factor** (`float`): the regularization factor. Note that it corresponds to $\frac{\lambda}{2}$ rather than $\lambda$ in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114).
        """
        super().__init__()

        self.factor = factor
        """The regularization factor for parameter change."""

    def forward(
        self,
        target_model: nn.Module,
        ref_model: nn.Module,
        weights: dict[str, Tensor],
    ) -> Tensor:
        r"""Calculate the regularization loss.

        **Args:**
        - **target_model** (`nn.Module`): the model holding the target parameters. In EWC, it is the model of the current training task.
        - **ref_model** (`nn.Module`): the reference model that the target model's parameters should not deviate from. It must have the same structure as the target model. In EWC, it is the model of one of the previous tasks.
        - **weights** (`dict[str, Tensor]`): the regularization weight for each parameter. Keys are parameter names and values are the weight tensors, which must match the shapes of the model parameters. In EWC, these are the parameter importance measures calculated from the Fisher information.

        **Returns:**
        - **reg** (`Tensor`): the parameter change regularization value.
        """
        reg = 0.0

        # Compute the weighted squared difference for each parameter
        for (param_name, target_param), (_, ref_param) in zip(
            target_model.named_parameters(), ref_model.named_parameters()
        ):
            weight = weights[param_name]
            # Element-wise squared difference multiplied by importance
            reg += torch.sum(weight * (target_param - ref_param).pow(2))

        return self.factor * reg
class ParameterChangeReg(torch.nn.modules.module.Module):

Parameter change regularizer.

$$R(\theta) = \text{factor} \cdot \sum_i w_i \|\theta_i - \theta^\star_i\|^2$$

It penalizes the target set of parameters $\theta = \{\theta_i\}_i$ for deviating from a reference set of parameters $\theta^\star = \{\theta^\star_i\}_i$, measured by the $L^2$ distance. The regularization can be weighted parameter-wise, through the $w_i$ in the formula.

It is used in:

  • L2 Regularization algorithm: as an L2 regularizer that prevents the current task parameters from changing too much from the previous task parameters.
  • EWC (Elastic Weight Consolidation) algorithm: as a weighted L2 regularizer that prevents the current task parameters from changing too much from the previous task parameters. The regularization weights are parameter importance measures calculated from the Fisher information. See equation 3 in the EWC paper.
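A minimal sketch of how such a regularizer plugs into an EWC-style training step. The tiny model, random data, and uniform importance weights here are illustrative stand-ins; EWC would derive the weights from the Fisher information instead:

```python
import copy

import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(4, 2)
ref_model = copy.deepcopy(model)  # frozen snapshot of the previous-task parameters
for p in ref_model.parameters():
    p.requires_grad_(False)

# Uniform unit weights as a placeholder for Fisher-based importance.
weights = {n: torch.ones_like(p) for n, p in model.named_parameters()}
factor = 0.5  # corresponds to lambda / 2 in the EWC paper's notation

opt = optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 2)

for _ in range(3):
    opt.zero_grad()
    task_loss = nn.functional.mse_loss(model(x), y)
    # The same weighted squared difference that ParameterChangeReg.forward accumulates:
    reg = factor * sum(
        (weights[n] * (p - rp).pow(2)).sum()
        for (n, p), (_, rp) in zip(
            model.named_parameters(), ref_model.named_parameters()
        )
    )
    (task_loss + reg).backward()
    opt.step()
```

On the first step the penalty is exactly zero (the two models coincide); as training pulls the parameters away from the snapshot, the penalty grows and counteracts the drift.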
ParameterChangeReg(factor: float)

Args:

  • factor (float): the regularization factor. Note that it corresponds to $\frac{\lambda}{2}$ rather than $\lambda$ in the EWC paper.
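For reference, with weights $w_i = F_i$ (the diagonal Fisher information) and `factor` set to $\frac{\lambda}{2}$, the regularizer reproduces the EWC penalty of equation 3 in the paper:

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^\star_{A,i}\right)^2$$

so a $\lambda$ taken from the paper should be halved before being passed as `factor`.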
factor

The regularization factor for parameter change.

def forward( self, target_model: torch.nn.modules.module.Module, ref_model: torch.nn.modules.module.Module, weights: dict[str, torch.Tensor]) -> torch.Tensor:

Calculate the regularization loss.

Args:

  • target_model (nn.Module): the model holding the target parameters. In EWC, it is the model of the current training task.
  • ref_model (nn.Module): the reference model that the target model's parameters should not deviate from. It must have the same structure as the target model. In EWC, it is the model of one of the previous tasks.
  • weights (dict[str, Tensor]): the regularization weight for each parameter. Keys are parameter names and values are the weight tensors, which must match the shapes of the model parameters. In EWC, these are the parameter importance measures calculated from the Fisher information.

Returns:

  • reg (Tensor): the parameter change regularization value.
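To make the formula concrete, here is a self-contained numeric check that reproduces `forward`'s computation on two tiny linear layers. The uniform unit weights are an illustrative choice, under which the penalty reduces to a plain squared $L^2$ distance between the two parameter sets:

```python
import torch
from torch import nn

torch.manual_seed(0)
target = nn.Linear(2, 2, bias=False)  # model holding the target parameters theta
ref = nn.Linear(2, 2, bias=False)     # reference parameters theta*

# Unit weight for every parameter; EWC would use Fisher-based importance here.
weights = {n: torch.ones_like(p) for n, p in target.named_parameters()}
factor = 0.5

# The same weighted squared difference that forward() accumulates:
reg = factor * sum(
    torch.sum(weights[n] * (p - rp).pow(2))
    for (n, p), (_, rp) in zip(
        target.named_parameters(), ref.named_parameters()
    )
)

# With unit weights this is exactly factor * ||theta - theta*||_2^2.
expected = factor * sum(
    (p - rp).pow(2).sum()
    for p, rp in zip(target.parameters(), ref.parameters())
)
assert torch.allclose(reg, expected)
```

Passing non-uniform `weights` simply rescales each parameter's contribution, which is how EWC concentrates the penalty on parameters deemed important for earlier tasks.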