clarena.cl_algorithms.regularizers.parameter_change
The submodule in regularizers for parameter change regularization.
1r"""The submodule in `regularizers` for parameter change regularization.""" 2 3__all__ = ["ParameterChangeReg"] 4 5import logging 6 7import torch 8from torch import Tensor, nn 9 10# always get logger for built-in logging in each module 11pylogger = logging.getLogger(__name__) 12 13 14class ParameterChangeReg(nn.Module): 15 r"""Parameter change regularizer. 16 17 $$R(\theta) = \text{factor} * \sum_i w_i \|\theta_i - \theta^\star_i\|^2$$ 18 19 It promotes the target set of parameters $\theta = {\theta_i}_i$ not changing too much from another set of parameters $\theta^\star = {\theta^\star_i}_i$. The parameter distance here is $L^2$ distance. The regularization can be parameter-wise weighted, i.e. $w_i$ in the formula. 20 21 It is used in: 22 - [L2 Regularization algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as a L2 regularizer for the current task parameters to prevent them from changing too much from the previous task parameters. 23 - [EWC (Elastic Weight Consolidation) algorithm](https://www.pnas.org/doi/10.1073/pnas.1611835114): as a weighted L2 regularizer for the current task parameters to prevent them from changing too much from the previous task parameters. The regularization weights are parameter importance measure calculated from fisher information. See equation 3 in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114). 24 """ 25 26 def __init__( 27 self, 28 factor: float, 29 ) -> None: 30 r""" 31 **Args:** 32 - **factor** (`float`): the regularization factor. Note that it is $\frac{\lambda}{2}$ rather than $\lambda$ in the [EWC paper](https://www.pnas.org/doi/10.1073/pnas.1611835114). 33 """ 34 super().__init__() 35 36 self.factor = factor 37 """The regularization factor for parameter change.""" 38 39 def forward( 40 self, 41 target_model: nn.Module, 42 ref_model: nn.Module, 43 weights: dict[str, Tensor], 44 ) -> Tensor: 45 r"""Calculate the regularization loss. 
46 47 **Args:** 48 - **target_model** (nn.Module): the model of the target parameters. In EWC, it's the model of current training task. 49 - **ref_model** (nn.Module): the reference model that you want target model parameters to prevent changing from. The reference model must have the same structure as the target model. In EWC, it's the model of one of the previous tasks. 50 - **weights** (dict[str, Tensor]): the regularization weight for each parameter. Keys are parameter names and values are the weight tensors. The weight tensors must match the shape of model parameters. In EWC, it's the importance measure of each parameter, calculated from fisher information thing. 51 52 **Returns:** 53 - **reg** (Tensor): the parameter change regularization value. 54 """ 55 reg = 0.0 56 57 # Compute the weighted squared difference for each parameter 58 for (param_name, target_param), (_, ref_param) in zip( 59 target_model.named_parameters(), ref_model.named_parameters() 60 ): 61 weight = weights[param_name] 62 # Element-wise squared difference multiplied by importance 63 reg += torch.sum(weight * (target_param - ref_param).pow(2)) 64 65 return self.factor * reg
class ParameterChangeReg(torch.nn.modules.module.Module):
Parameter change regularizer.
$$R(\theta) = \text{factor} * \sum_i w_i \|\theta_i - \theta^\star_i\|^2$$
It encourages the target set of parameters $\theta = \{\theta_i\}_i$ not to change too much from a reference set of parameters $\theta^\star = \{\theta^\star_i\}_i$. The parameter distance here is the $L^2$ distance, and the regularization can be parameter-wise weighted, i.e. the $w_i$ in the formula.
It is used in:
- L2 Regularization algorithm: as an L2 regularizer that prevents the current task's parameters from changing too much from the previous task's parameters.
- EWC (Elastic Weight Consolidation) algorithm: as a weighted L2 regularizer with the same purpose. The regularization weights are parameter importance measures calculated from the Fisher information. See equation 3 in the EWC paper.
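The Fisher-based weights mentioned above can be sketched as follows. This is a hedged illustration, not part of this module: `fisher_diagonal`, `loader`, and `loss_fn` are hypothetical names, and the estimate shown is the common empirical diagonal approximation (mean of squared gradients over batches).

```python
import torch
from torch import nn

# Hedged sketch of EWC-style importance weights (diagonal Fisher information),
# suitable to pass as the `weights` argument of ParameterChangeReg.forward.
# `fisher_diagonal`, `loader`, and `loss_fn` are illustrative names only.
def fisher_diagonal(model: nn.Module, loader, loss_fn) -> dict[str, torch.Tensor]:
    fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    for x, y in loader:
        model.zero_grad()
        loss_fn(model(x), y).backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                # accumulate squared gradients as the Fisher diagonal estimate
                fisher[name] += p.grad.detach().pow(2)
    # average over the number of batches
    return {name: f / len(loader) for name, f in fisher.items()}

# tiny demonstration with one batch of fabricated data
model = nn.Linear(2, 1)
loader = [(torch.randn(4, 2), torch.randn(4, 1))]
weights = fisher_diagonal(model, loader, nn.MSELoss())
```

The resulting dictionary has one non-negative tensor per named parameter, shaped like that parameter, which matches what `forward` expects for `weights`.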
ParameterChangeReg(factor: float)
Args:
- factor (float): the regularization factor. Note that it is $\frac{\lambda}{2}$ rather than $\lambda$ in the EWC paper.
def forward(self, target_model: torch.nn.modules.module.Module, ref_model: torch.nn.modules.module.Module, weights: dict[str, torch.Tensor]) -> torch.Tensor:
Calculate the regularization loss.
Args:
- target_model (nn.Module): the model holding the target parameters. In EWC, it is the model of the current training task.
- ref_model (nn.Module): the reference model that the target model's parameters are kept close to. It must have the same structure as the target model. In EWC, it is the model of one of the previous tasks.
- weights (dict[str, Tensor]): the regularization weight for each parameter. Keys are parameter names and values are weight tensors matching the shapes of the model parameters. In EWC, these are the parameter importance measures calculated from the Fisher information.
Returns:
- reg (Tensor): the parameter change regularization value.
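A minimal usage sketch with hand-picked parameter values, so the regularization value can be verified by inspection. The class body is restated from the source above to keep the snippet self-contained; the uniform weights stand in for EWC's Fisher-based importance measures.

```python
import torch
from torch import nn

# Restated from the source above so this example runs standalone
class ParameterChangeReg(nn.Module):
    def __init__(self, factor: float) -> None:
        super().__init__()
        self.factor = factor

    def forward(self, target_model, ref_model, weights):
        reg = 0.0
        for (name, p), (_, p_ref) in zip(
            target_model.named_parameters(), ref_model.named_parameters()
        ):
            reg += torch.sum(weights[name] * (p - p_ref).pow(2))
        return self.factor * reg

# Two structurally identical models with known parameter values
target = nn.Linear(2, 1, bias=False)
ref = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    target.weight.copy_(torch.tensor([[1.0, 3.0]]))
    ref.weight.copy_(torch.tensor([[0.0, 1.0]]))

# Uniform importance weights; EWC would use Fisher-based weights here
weights = {"weight": torch.ones_like(target.weight)}

reg_fn = ParameterChangeReg(factor=0.5)
loss = reg_fn(target, ref, weights)
# (1-0)^2 + (3-1)^2 = 5, times factor 0.5 -> 2.5
```

The returned value is a scalar tensor attached to the autograd graph of `target_model`, so it can simply be added to the task loss before calling `backward()`.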